Copyright notice: This technical column is a summary and distillation of the author's (Qin Kaixin) day-to-day work, drawing on cases extracted from real commercial environments and offering tuning advice and cluster capacity-planning guidance for commercial applications. Please keep following this blog. QQ email: 1120746959@qq.com; feel free to reach out for any technical exchange.
MutableAggregationBuffer is an array; here we read the buffered string with buffer.getString(0).
Each incoming string is appended to the value held in buffer.getString(0).
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

class GroupConcatDistinctUDAF extends UserDefinedAggregateFunction {

  // Input data type
  override def inputSchema: StructType = StructType(StructField("cityInfo", StringType) :: Nil)

  // Buffer data type
  override def bufferSchema: StructType = StructType(StructField("bufferCityInfo", StringType) :: Nil)

  // Output data type
  override def dataType: DataType = StringType

  // Consistency: the same input always produces the same output
  override def deterministic: Boolean = true

  // Initialize the buffer with an empty string
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = ""
  }

  /**
   * Update:
   * the values within a group are passed in one at a time;
   * this implements the concatenation logic.
   */
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // The city-info string concatenated so far in the buffer
    var bufferCityInfo = buffer.getString(0)
    // The city info that has just been passed in
    val cityInfo = input.getString(0)
    // Deduplication: only concatenate a city info that has not been concatenated before
    if (!bufferCityInfo.contains(cityInfo)) {
      if ("".equals(bufferCityInfo))
        bufferCityInfo += cityInfo
      else {
        // e.g. 1:北京
        //      1:北京,2:上海
        bufferCityInfo += "," + cityInfo
      }
      buffer.update(0, bufferCityInfo)
    }
  }

  // Merge the partial buffers from different partitions, keeping the result deduplicated
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    var bufferCityInfo1 = buffer1.getString(0)
    val bufferCityInfo2 = buffer2.getString(0)
    for (cityInfo <- bufferCityInfo2.split(",")) {
      if (!bufferCityInfo1.contains(cityInfo)) {
        if ("".equals(bufferCityInfo1)) {
          bufferCityInfo1 += cityInfo
        } else {
          bufferCityInfo1 += "," + cityInfo
        }
      }
    }
    buffer1.update(0, bufferCityInfo1)
  }

  // Return the final concatenated string
  override def evaluate(buffer: Row): Any = {
    buffer.getString(0)
  }
}
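A minimal usage sketch (the table name city_clicks and its columns are hypothetical, used only for illustration): register the UDAF under a SQL-callable name and invoke it inside a GROUP BY.

// Assumes an existing SparkSession named spark.
spark.udf.register("group_concat_distinct", new GroupConcatDistinctUDAF())
// Hypothetical table city_clicks(area, city_id, city_name):
// build the deduplicated "id:name" list of the cities seen in each area.
spark.sql(
  "SELECT area, group_concat_distinct(concat(CAST(city_id AS STRING), ':', city_name)) AS city_infos " +
  "FROM city_clicks GROUP BY area"
).show()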
Analyzing the data
The first column is user_id, the second is item_id, and the third is score.
0162381440670851711,4,7.0
0162381440670851711,11,4.0
0162381440670851711,32,1.0
0162381440670851711,176,27.0
0162381440670851711,183,11.0
0162381440670851711,184,5.0
0162381440670851711,207,9.0
0162381440670851711,256,3.0
0162381440670851711,258,4.0
0162381440670851711,259,16.0
0162381440670851711,260,8.0
0162381440670851711,261,18.0
0162381440670851711,301,1.0
1. inputSchema
inputSchema defines the schema of the input data. It must be a StructType, whose argument is a sequence of StructField elements. For example, to define the schema of the score column, declare a StructField named score_column with data type DoubleType. Since the input consists of only this one score column, the StructField sequence contains a single element.
override def inputSchema: StructType = StructType(StructField("score_column",DoubleType)::Nil)
2. bufferSchema
Computing the average score requires intermediate data: the running sum of scores and the count of scores. bufferSchema defines these buffer fields.
override def bufferSchema: StructType = StructType(StructField("sum",DoubleType)::StructField("count",LongType)::Nil)
3. dataType
dataType declares the final output type of the custom aggregate function. For example, the computed average score is of type Double.
override def dataType: DataType = DoubleType
4. deterministic
deterministic is a consistency flag, a Boolean value: when it is true, the same input always yields the same output. Since the same score inputs must always produce the same average, it is set to true.
override def deterministic: Boolean = true
5. initialize
initialize is used to initialize the buffer. The score buffer has two fields, sum and count, which are initialized to sum = 0.0 and count = 0L; the first is a Double and the second a Long.
override def initialize(buffer: MutableAggregationBuffer): Unit = {
  // sum = 0.0
  buffer(0) = 0.0
  // count = 0
  buffer(1) = 0L
}
6. update
When new input arrives, update refreshes the buffer variables. Here, when a new score comes in, it is added to sum and count is incremented by 1.
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
  // Skip null input
  if (!input.isNullAt(0)) {
    // sum = sum + incoming score
    buffer(0) = buffer.getDouble(0) + input.getDouble(0)
    // count = count + 1
    buffer(1) = buffer.getLong(1) + 1
  }
}
7. merge
merge combines the partial buffers computed on different partitions: the two sums are added together, and so are the two counts.
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
  buffer1(0) = buffer1.getDouble(0) + buffer2.getDouble(0)
  buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
}
8. evaluate
evaluate computes the final result. Here it calculates the average score: average(score) = sum(score) / count(score).
override def evaluate(buffer: Row): Double = buffer.getDouble(0)/buffer.getLong(1)
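Putting the eight pieces above together: below is a minimal sketch of the complete untyped UDAF for the average score (the class name MyAverageUDAF and the SQL name my_average are illustrative assumptions, not taken from the original project).

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

// Untyped UDAF assembled from the methods discussed in sections 1-8 (sketch).
class MyAverageUDAF extends UserDefinedAggregateFunction {
  override def inputSchema: StructType = StructType(StructField("score_column", DoubleType) :: Nil)
  override def bufferSchema: StructType = StructType(StructField("sum", DoubleType) :: StructField("count", LongType) :: Nil)
  override def dataType: DataType = DoubleType
  override def deterministic: Boolean = true
  override def initialize(buffer: MutableAggregationBuffer): Unit = { buffer(0) = 0.0; buffer(1) = 0L }
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      buffer(0) = buffer.getDouble(0) + input.getDouble(0)
      buffer(1) = buffer.getLong(1) + 1
    }
  }
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getDouble(0) + buffer2.getDouble(0)
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }
  override def evaluate(buffer: Row): Double = buffer.getDouble(0) / buffer.getLong(1)
}

// Hypothetical registration and SQL usage, assuming a view named itemdata with a score column:
// spark.udf.register("my_average", new MyAverageUDAF())
// spark.sql("SELECT user_id, my_average(score) AS average_score FROM itemdata GROUP BY user_id").show()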
Data is used to hold the records from itemdata.data, and Average holds the intermediate data for computing the average score. Note that Average's fields sum and count must both be declared as var (reduce mutates them in place).
case class Data(user_id: String, item_id: String, score: Double)
case class Average(var sum: Double,var count: Long)
Full source code
Aggregate function as a column: toColumn.name("average_score")
Using the aggregate function: dataDS.select(averageScore).show()
import org.apache.spark.sql.{Encoder, Encoders, SparkSession}
import org.apache.spark.sql.expressions.Aggregator

/**
 * Type-safe custom aggregate function
 */
object TypeSafeMyAverageTest {

  case class Data(user_id: String, item_id: String, score: Double)
  case class Average(var sum: Double, var count: Long)

  object SafeMyAverage extends Aggregator[Data, Average, Double] {
    // zero corresponds to initialize in the untyped UDAF above: it initializes the Average holding the intermediate data
    override def zero: Average = Average(0.0D, 0L)

    // reduce corresponds to update above: when a new record a arrives, fold it into the intermediate data b
    override def reduce(b: Average, a: Data): Average = {
      b.sum += a.score
      b.count += 1L
      b
    }

    override def merge(b1: Average, b2: Average): Average = {
      b1.sum += b2.sum
      b1.count += b2.count
      b1
    }

    override def finish(reduction: Average): Double = reduction.sum / reduction.count

    // Encoder for the buffer type
    override def bufferEncoder: Encoder[Average] = Encoders.product

    // Encoder for the final output type
    override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
  }

  def main(args: Array[String]): Unit = {
    // Create the Spark SQL entry point
    val spark = SparkSession.builder().master("local").appName("My-Average").getOrCreate()
    // Read itemdata.data from HDFS into an RDD
    val rdd = spark.sparkContext.textFile("hdfs://192.168.189.21:8020/input/mahout-demo/itemdata.data")
    // Convert the RDD into a Dataset
    import spark.implicits._
    val dataDS = rdd.map(_.split(",")).map(d => Data(d(0), d(1), d(2).toDouble)).toDS()
    // Apply the custom aggregate function
    val averageScore = SafeMyAverage.toColumn.name("average_score")
    dataDS.select(averageScore).show()
  }
}
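The select above computes a single global average over the whole Dataset. As a small extension sketch (reusing the dataDS and averageScore values from the code above), the same typed aggregator can also be applied per key with groupByKey:

// Per-user average score with the same typed aggregator (sketch).
// groupByKey keys the Dataset by user_id; agg applies the TypedColumn within each key.
val perUserAvg = dataDS.groupByKey(_.user_id).agg(averageScore)
perUserAvg.show()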
// Execution ID of this task: uniquely identifies this run's results in the MySQL database
val taskUUID = UUID.randomUUID().toString
// Build the Spark configuration
val sparkConf = new SparkConf().setAppName("SessionAnalyzer").setMaster("local[*]")
// Create the Spark session (with Hive support)
val spark = SparkSession.builder().config(sparkConf).enableHiveSupport().getOrCreate()
val sc = spark.sparkContext
// Register custom functions
spark.udf.register("concat_long_string", (v1: Long, v2: String, split: String) => v1.toString + split + v2)
spark.udf.register("get_json_object", (json: String, field: String) => {
  val jsonObject = JSONObject.fromObject(json)
  jsonObject.getString(field)
})
spark.udf.register("group_concat_distinct", new GroupConcatDistinctUDAF())
// Get the task's date-range parameters
val startDate = ParamUtils.getParam(taskParam, Constants.PARAM_START_DATE)
val endDate = ParamUtils.getParam(taskParam, Constants.PARAM_END_DATE)
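A quick, hypothetical sanity check of the two scalar UDFs registered above (the literal values are made up for illustration; the custom get_json_object is registered under the same name as Spark's built-in function and should shadow it within this session):

// concat_long_string joins a Long id and a name with the given separator, e.g. "1:北京".
spark.sql("SELECT concat_long_string(1L, '北京', ':') AS city_info").show()
// The custom get_json_object pulls a plain field name out of a JSON string, e.g. "0".
spark.sql("""SELECT get_json_object('{"product_status":0}', 'product_status') AS product_status""").show()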
val cityid2clickActionRDD = getcityid2ClickActionRDDByDate(spark, startDate, endDate)

def getcityid2ClickActionRDDByDate(spark: SparkSession, startDate: String, endDate: String): RDD[(Long, Row)] = {
  // Query user visit/action data from user_visit_action
  // First restriction: click_product_id must be non-null, which identifies click actions
  // Second restriction: the data must fall within the user-specified date range
  val sql =
    "SELECT " +
      "city_id," +
      "click_product_id " +
      "FROM user_visit_action " +
      "WHERE click_product_id IS NOT NULL and click_product_id != -1L " +
      "AND date>='" + startDate + "' " +
      "AND date<='" + endDate + "'"
  val clickActionDF = spark.sql(sql)
  // (cityid, row)
  clickActionDF.rdd.map(item => (item.getAs[Long]("city_id"), item))
}
def getcityid2CityInfoRDD(spark: SparkSession): RDD[(Long, Row)] = {
  val cityInfo = Array((0L, "北京", "华北"), (1L, "上海", "华东"), (2L, "南京", "华东"), (3L, "广州", "华南"), (4L, "三亚", "华南"), (5L, "武汉", "华中"), (6L, "长沙", "华中"), (7L, "西安", "西北"), (8L, "成都", "西南"), (9L, "哈尔滨", "东北"))
  import spark.implicits._
  val cityInfoDF = spark.sparkContext.makeRDD(cityInfo).toDF("city_id", "city_name", "area")
  cityInfoDF.rdd.map(item => (item.getAs[Long]("city_id"), item))
}

// City info keyed by city_id: (city_id, city info row)
val cityid2cityInfoRDD = getcityid2CityInfoRDD(spark)
// Join the click actions (cityid2clickActionRDD) with the city info (cityid2cityInfoRDD)
// and register the result as tmp_click_product_basic
generateTempClickProductBasicTable(spark, cityid2clickActionRDD, cityid2cityInfoRDD)

def generateTempClickProductBasicTable(spark: SparkSession, cityid2clickActionRDD: RDD[(Long, Row)], cityid2cityInfoRDD: RDD[(Long, Row)]) {
  // Join the click-action data with the city data
  val joinedRDD = cityid2clickActionRDD.join(cityid2cityInfoRDD)
  // Map the joined pair RDD into an RDD of tuples so that it can be converted into a DataFrame
  val mappedRDD = joinedRDD.map { case (cityid, (action, cityinfo)) =>
    val productid = action.getLong(1)
    val cityName = cityinfo.getString(1)
    // alternatively: cityinfo.getAs[String]("area")
    val area = cityinfo.getString(2)
    (cityid, cityName, area, productid)
  }
  // e.g. 1 北京
  //      2 上海
  //      1 北京
  // group by area, product_id
  // -> 1:北京,2:上海
  // Two functions are used for this later:
  // UDF: concat_long_string(), concatenates two fields with a given separator
  // UDAF: group_concat_distinct(), concatenates the values within a group with commas, with deduplication
  import spark.implicits._
  val df = mappedRDD.toDF("city_id", "city_name", "area", "product_id")
  // Register df as a temporary view
  df.createOrReplaceTempView("tmp_click_product_basic")
}
generateTempAreaPrdocutClickCountTable(spark)

def generateTempAreaPrdocutClickCountTable(spark: SparkSession) {
  // Group by the two fields area and product_id
  // to compute the click count of each product in each area,
  // and build, for every (area, product_id), the concatenated string of its city info
  val sql = "SELECT " +
    "area," +
    "product_id," +
    "count(*) click_count, " +
    "group_concat_distinct(concat_long_string(city_id,city_name,':')) city_infos " +
    "FROM tmp_click_product_basic " +
    "GROUP BY area,product_id "
  val df = spark.sql(sql)
  // Click counts per area and product (plus the extra city list); register the result as another temporary view
  df.createOrReplaceTempView("tmp_area_product_click_count")
}
generateTempAreaFullProductClickCountTable(spark)

Join tmp_area_product_click_count with product_info so that the click-count table also carries detailed product information.

def generateTempAreaFullProductClickCountTable(spark: SparkSession) {
  // Join the per-area per-product click-count table with the product info table on product_id
  // to bring in product_name and product_status
  // product_status needs special handling: 0 and 1 stand for self-operated and third-party products, and the value is stored inside a JSON string
  // get_json_object() extracts the value of a given field from a JSON string
  // if() decides: if product_status is 0 the product is self-operated; if it is 1, it is third-party
  // Result columns: area, product_id, click_count, city_infos, product_name, product_status
  // Whether a top-3 product in an area is self-operated or third-party is actually an important piece of information
  // Technique highlighted here: the built-in if() function
  val sql = "SELECT " +
    "tapcc.area," +
    "tapcc.product_id," +
    "tapcc.click_count," +
    "tapcc.city_infos," +
    "pi.product_name," +
    "if(get_json_object(pi.extend_info,'product_status')='0','Self','Third Party') product_status " +
    "FROM tmp_area_product_click_count tapcc " +
    "JOIN product_info pi ON tapcc.product_id=pi.product_id "
  val df = spark.sql(sql)
  df.createOrReplaceTempView("tmp_area_fullprod_click_count")
}
val areaTop3ProductRDD = getAreaTop3ProductRDD(taskUUID, spark)

def getAreaTop3ProductRDD(taskid: String, spark: SparkSession): DataFrame = {
  // Areas: 华北 (North China), 华东 (East China), 华南 (South China), 华中 (Central China), 西北 (Northwest), 西南 (Southwest), 东北 (Northeast)
  // A level: 华北, 华东
  // B level: 华南, 华中
  // C level: 西北, 西南
  // D level: 东北
  // case when: maps different conditions to different values
  // case when ... then ... when ... then ... else ... end
  // NOTE: the area literals below ('China North', 'China East', ...) must match the area values actually
  // stored in the table; with the Chinese area names produced by getcityid2CityInfoRDD above,
  // every row would fall through to 'D Level'.
  val sql = "SELECT " +
    "area," +
    "CASE " +
    "WHEN area='China North' OR area='China East' THEN 'A Level' " +
    "WHEN area='China South' OR area='China Middle' THEN 'B Level' " +
    "WHEN area='West North' OR area='West South' THEN 'C Level' " +
    "ELSE 'D Level' " +
    "END area_level," +
    "product_id," +
    "city_infos," +
    "click_count," +
    "product_name," +
    "product_status " +
    "FROM (" +
    "SELECT " +
    "area," +
    "product_id," +
    "click_count," +
    "city_infos," +
    "product_name," +
    "product_status," +
    "row_number() OVER (PARTITION BY area ORDER BY click_count DESC) rank " +
    "FROM tmp_area_fullprod_click_count " +
    ") t " +
    "WHERE rank<=3"
  spark.sql(sql)
}
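For comparison, here is a minimal sketch of the same top-3-per-area ranking written with the DataFrame/Window API instead of SQL, assuming the temporary view tmp_area_fullprod_click_count registered above (the CASE WHEN area-level mapping is omitted):

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, row_number}

// Rank products within each area by click_count and keep the top 3 (sketch).
val w = Window.partitionBy("area").orderBy(col("click_count").desc)
val top3DF = spark.table("tmp_area_fullprod_click_count")
  .withColumn("rank", row_number().over(w))
  .where(col("rank") <= 3)
top3DF.show()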
import spark.implicits._
val areaTop3ProductDF = areaTop3ProductRDD.rdd.map(row =>
  AreaTop3Product(taskUUID, row.getAs[String]("area"), row.getAs[String]("area_level"), row.getAs[Long]("product_id"), row.getAs[String]("city_infos"), row.getAs[Long]("click_count"), row.getAs[String]("product_name"), row.getAs[String]("product_status"))
).toDS
areaTop3ProductDF.write
  .format("jdbc")
  .option("url", ConfigurationManager.config.getString(Constants.JDBC_URL))
  .option("dbtable", "area_top3_product")
  .option("user", ConfigurationManager.config.getString(Constants.JDBC_USER))
  .option("password", ConfigurationManager.config.getString(Constants.JDBC_PASSWORD))
  .mode(SaveMode.Append)
  .save()
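The AreaTop3Product case class used in the mapping above is not shown in this excerpt. A plausible definition, inferred from the arguments passed to its constructor (the field names are assumptions), would be:

// Hypothetical shape of the result record written to MySQL, inferred from the call site above.
case class AreaTop3Product(
  taskid: String,
  area: String,
  areaLevel: String,
  productid: Long,
  cityInfos: String,
  clickCount: Long,
  productName: String,
  productStatus: String
)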
Reviewing the old to learn the new: this post is a code summary written for comprehensive revision, so please forgive its roughness.
Qin Kaixin, Shenzhen