背景:有时候咱们须要定义一个外部数据源,而后用spark sql的方式来处理。这样的好处有2点:sql
(1)定义了外部数据源后,用起来很简洁,软件架构清晰,经过sql方式直接使用。apache
(2)容易分层分模块,一层层往上搭建,容易屏蔽实现细节。架构
这时候就须要用到定义外部数据源的方式,spark中使用起来也很简单的,所谓会者不难。app
首先指定个package名,全部的类在统一的package下。好比com.example.hou。ide
而后定义两个东西,一个是DefaultSource,一个是BaseRelation with TableScan的子类。ui
DefaultSource的代码很简单,直接看代码不解释:this
package com.example.hou import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode} import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, SchemaRelationProvider} import org.apache.spark.sql.types.StructType class DefaultSource extends CreatableRelationProvider with SchemaRelationProvider{ def createRelation( sqlContext: SQLContext, parameters: Map[String, String], schema: StructType): BaseRelation = { val path = parameters.get("path") path match { case Some(x) => new TextDataSourceRelation(sqlContext,x,schema) case _ => throw new IllegalArgumentException("path is required...") } } override def createRelation(sqlContext: SQLContext, mode: SaveMode, parameters: Map[String, String], data: DataFrame): BaseRelation = { createRelation(sqlContext,parameters,null) } }
TextDataSourceRelation的源码:spa
package com.example.hou import org.apache.spark.sql.types.LongType import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.rdd.RDD import org.apache.spark.sql.sources.TableScan import org.apache.spark.sql.SQLContext import org.apache.spark.sql.types.StructType import org.apache.spark.sql.types.StructField import org.apache.spark.sql.types.StringType import org.apache.spark.internal.Logging import org.apache.spark.sql.Row class TextDataSourceRelation (override val sqlContext: SQLContext,path:String,userSchema: StructType) extends BaseRelation with TableScan with Logging{ //若是传进来的schema不为空,就用传进来的schema,不然就用自定义的schema override def schema: StructType = { if(userSchema != null){ userSchema }else{ StructType( StructField("id",LongType,false) :: StructField("name",StringType,false) :: StructField("gender",StringType,false) :: StructField("salary",LongType,false) :: StructField("comm",LongType,false) :: Nil ) } } //把数据读进来,读进来以后把它转换成 RDD[Row] override def buildScan(): RDD[Row] = { logWarning("this is ruozedata buildScan....") //读取数据,变成为RDD //wholeTextFiles会把文件名读进来,能够经过map(_._2)把文件名去掉,第一位是文件名,第二位是内容 val rdd = sqlContext.sparkContext.wholeTextFiles(path).map(_._2) //拿到schema val schemaField = schema.fields //rdd.collect().foreach(println) //rdd + schemaField 把rdd和schemaField解析出来拼起来 val rows = rdd.map(fileContent => { //拿到每一行的数据 val lines = fileContent.split("\n") //每一行数据按照逗号分隔,分隔以后去空格,而后转成一个seq集合 val data = lines.filter(line=>{!line.trim().contains("//")}).map(_.split(",").map(_.trim)).toSeq //zipWithIndex val result = data.map(x => x.zipWithIndex.map { case (value, index) => { val columnName = schemaField(index).name //castTo里面有两个参数,第一个参数须要给个判断,若是是字段是性别,里面再进行判断再转换一下,若是不是性别就直接用这个字段 Utils.castTo(if(columnName.equalsIgnoreCase("gender")){ if(value == "0"){ "man" }else if(value == "1"){ "woman" } else{ "unknown" } }else{ value },schemaField(index).dataType) } }) result.map(x => Row.fromSeq(x)) }) rows.flatMap(x => x) } }
最后一句就是在Main方法中使用:code
package com.example.hou import org.apache.spark.sql.SparkSession import org.apache.spark.SparkConf import org.apache.spark.sql.types.StructType import org.apache.spark.sql.types.LongType import org.apache.spark.sql.types.StructField import org.apache.spark.sql.types.StringType object TestApp { def main(args: Array[String]): Unit = { val spark = SparkSession.builder() .appName("TextApp") .master("local[2]") .getOrCreate() //定义Schema val schema = StructType( StructField("id", LongType, false) :: StructField("name", StringType, false) :: StructField("gender", StringType, false) :: StructField("salary", LongType, false) :: StructField("comm", LongType, false) :: Nil) //只要写到包名就能够了...example.hou,不用这样写...example.hou.DefaultSource val df = spark.read.format("com.example.hou") .option("path", "C://code//data.txt").schema(schema).load() df.show() df.createOrReplaceTempView("test") spark.sql("select name,salary from test").show() println("Application Ended...") spark.stop() } }