Custom External Data Sources in Spark

Background: sometimes we need to define a custom external data source and then process it with Spark SQL. This has two benefits:

(1) Once the external data source is defined, it is very clean to use: the software architecture stays clear, and the data can be queried directly through SQL.

(2) It makes it easy to split the system into layers and modules, building up layer by layer while keeping implementation details hidden.

This is where a custom external data source comes in, and defining one in Spark is quite simple once you know how.

First pick a package name and keep all the classes under that single package, for example com.example.hou.

Then define two things: a DefaultSource class and a subclass of BaseRelation with TableScan.

The code for DefaultSource is simple enough to speak for itself:

package com.example.hou

import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode}
import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, SchemaRelationProvider}
import org.apache.spark.sql.types.StructType


class DefaultSource extends CreatableRelationProvider with SchemaRelationProvider {

  // Read path (SchemaRelationProvider): called when the caller supplies a schema.
  override def createRelation(
      sqlContext: SQLContext,
      parameters: Map[String, String],
      schema: StructType): BaseRelation = {
    val path = parameters.get("path")

    path match {
      case Some(x) => new TextDataSourceRelation(sqlContext, x, schema)
      case _ => throw new IllegalArgumentException("path is required...")
    }
  }

  // Write path (CreatableRelationProvider): simply delegates to the read path.
  override def createRelation(
      sqlContext: SQLContext,
      mode: SaveMode,
      parameters: Map[String, String],
      data: DataFrame): BaseRelation = {
    createRelation(sqlContext, parameters, null)
  }

}
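As written, this source requires the caller to supply a schema, because the read path only implements SchemaRelationProvider. If you also want spark.read.format("com.example.hou").load() to work without a .schema(...) call, the class could additionally mix in RelationProvider and fall back to the default schema built into TextDataSourceRelation. A minimal sketch of that variant, not part of the original post:

package com.example.hou

import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode}
import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, RelationProvider, SchemaRelationProvider}
import org.apache.spark.sql.types.StructType

class DefaultSource extends RelationProvider with SchemaRelationProvider with CreatableRelationProvider {

  // No schema from the caller: TextDataSourceRelation falls back to its default schema.
  override def createRelation(
      sqlContext: SQLContext,
      parameters: Map[String, String]): BaseRelation =
    createRelation(sqlContext, parameters, null)

  // Schema supplied by the caller via .schema(...).
  override def createRelation(
      sqlContext: SQLContext,
      parameters: Map[String, String],
      schema: StructType): BaseRelation =
    parameters.get("path") match {
      case Some(p) => new TextDataSourceRelation(sqlContext, p, schema)
      case _       => throw new IllegalArgumentException("path is required...")
    }

  // Write path: delegate to the read path.
  override def createRelation(
      sqlContext: SQLContext,
      mode: SaveMode,
      parameters: Map[String, String],
      data: DataFrame): BaseRelation =
    createRelation(sqlContext, parameters, null)
}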

The source of TextDataSourceRelation:

package com.example.hou

import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.sources.{BaseRelation, TableScan}
import org.apache.spark.sql.types.{LongType, StringType, StructField, StructType}

class TextDataSourceRelation(override val sqlContext: SQLContext, path: String, userSchema: StructType) extends BaseRelation with TableScan with Logging {
  // If a schema was passed in, use it; otherwise fall back to the default schema defined here.
  override def schema: StructType = {
    if(userSchema != null){
      userSchema
    }else{
      StructType(
        StructField("id",LongType,false) ::
          StructField("name",StringType,false) ::
          StructField("gender",StringType,false) ::
          StructField("salary",LongType,false) ::
          StructField("comm",LongType,false) :: Nil
      )
    }
  }

  // Read the data in and convert it into an RDD[Row].
  override def buildScan(): RDD[Row] = {
    logWarning("this is ruozedata buildScan....")
    // Read the files into an RDD.
    // wholeTextFiles returns (fileName, fileContent) pairs, so map(_._2) drops the file name and keeps only the content.
    val rdd = sqlContext.sparkContext.wholeTextFiles(path).map(_._2)
    // Get the schema fields.
    val schemaField = schema.fields

    //rdd.collect().foreach(println)

    // Combine the RDD with schemaField: parse each file and map its values onto the schema.
    val rows = rdd.map(fileContent => {
      // Split the file content into lines.
      val lines = fileContent.split("\n")
      // Split each line on commas and trim whitespace; lines containing "//" are skipped as comments, and the result is converted to a Seq.
      val data = lines.filter(line=>{!line.trim().contains("//")}).map(_.split(",").map(_.trim)).toSeq

      //zipWithIndex
      val result = data.map(x => x.zipWithIndex.map {
        case (value, index) => {

          val columnName = schemaField(index).name
          // castTo takes two parameters. For the first, check whether the field is gender: if so, map the numeric code to a label before casting; otherwise pass the value through unchanged.
          Utils.castTo(if(columnName.equalsIgnoreCase("gender")){
            if(value == "0"){
              "man"
            }else if(value == "1"){
              "woman"
            } else{
              "unknown"
            }
          }else{
            value
          },schemaField(index).dataType)

        }
      })

      result.map(x => Row.fromSeq(x))
    })

    rows.flatMap(x => x)

  }

}
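Note that buildScan calls Utils.castTo, which is not listed in the post. A minimal sketch of such a helper, assuming it only needs to handle the LongType and StringType columns used by the schema above (the object name and this implementation are an assumption, not taken from the original code):

package com.example.hou

import org.apache.spark.sql.types.{DataType, LongType, StringType}

object Utils {
  // Cast a raw string field to the target Spark SQL data type.
  // Only the types that appear in this example's schema are handled.
  def castTo(value: String, dataType: DataType): Any = dataType match {
    case LongType   => value.toLong
    case StringType => value
    case _          => value
  }
}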

The last piece is using it from a main method:

package com.example.hou

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{LongType, StringType, StructField, StructType}

object TestApp {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("TextApp")
      .master("local[2]")
      .getOrCreate()

    // Define the schema
    val schema = StructType(
      StructField("id", LongType, false) ::
        StructField("name", StringType, false) ::
        StructField("gender", StringType, false) ::
        StructField("salary", LongType, false) ::
        StructField("comm", LongType, false) :: Nil)

    // Writing the package name is enough: com.example.hou. There is no need to spell out com.example.hou.DefaultSource; Spark looks up the DefaultSource class inside that package.
    val df = spark.read.format("com.example.hou")
      .option("path", "C://code//data.txt").schema(schema).load()

    df.show()
    df.createOrReplaceTempView("test")
    spark.sql("select name,salary from test").show()

    println("Application Ended...")
    spark.stop()
  }

}
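For reference, the input file (C://code//data.txt in the snippet) is a plain comma-separated text file whose columns follow the schema (id, name, gender, salary, comm); gender is encoded numerically and any line containing "//" is skipped as a comment. The records below are made-up sample values, not data from the original post:

// id,name,gender,salary,comm
1,zhangsan,0,20000,2000
2,lisi,1,25000,2500
3,wangwu,2,30000,3000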