重要|Spark driver端获得executor返回值的方法

重要|Spark driver端获得executor返回值的方法

浪院长 浪尖聊大数据 mysql

重要|Spark driver端获得executor返回值的方法
有人说spark的代码不优雅,这个浪尖就忍不了了。实际上,说spark代码不优雅的主要是对scala不熟悉,spark代码我以为仍是很赞的,最值得阅读的大数据框架之一。
今天这篇文章不是为了争辩Spark 代码优雅与否,主要是讲一下理解了spark源码以后咱们能使用的一些小技巧吧。
spark 使用的时候,总有些需求比较另类吧,好比有球友问过这样一个需求:sql

浪尖,我想要在driver端获取executor执行task返回的结果,好比task是个规则引擎,我想知道每条规则命中了几条数据,请问这个怎么作呢?

这个是否是很骚气,也很常见,按理说你输出以后,在mysql里跑条sql就好了,可是这个每每显的比较麻烦。并且有时候,在 driver可能还要用到这些数据呢?具体该怎么作呢?数据库

大部分的想法估计是collect方法,那么用collect如何实现呢?你们本身能够考虑一下,我只能告诉你不简单,不如输出到数据库里,而后driver端写sql分析一下。apache

还有一种考虑就是使用自定义累加器。这样就能够在executor端将结果累加而后在driver端使用,不过具体实现也是很麻烦。你们也能够本身琢磨一下下~数组

那么,浪尖就给你们介绍一个比较经常使用也比较骚的操做吧。框架

其实,这种操做咱们最早想到的应该是count函数,由于他就是将task的返回值返回到driver端,而后进行聚合的。咱们能够从idea count函数点击进去,能够看到...elasticsearch

def count(): Long = sc.runJob(this, Utils.getIteratorSize _).sum

也便是sparkcontext的runJob方法。
Utils.getIteratorSize _这个方法主要是计算每一个iterator的元素个数,也便是每一个分区的元素个数,返回值就是元素个数:ide

/**
   * Counts the number of elements of an iterator using a while loop rather than calling
   * [[scala.collection.Iterator#size]] because it uses a for loop, which is slightly slower
   * in the current version of Scala.
   */
  def getIteratorSize[T](iterator: Iterator[T]): Long = {
    var count = 0L
    while (iterator.hasNext) {
      count += 1L
      iterator.next()
    }
    count
  }

而后就是runJob返回的是一个数组,每一个数组的元素就是咱们task执行函数的返回值,而后调用sum就获得咱们的统计值了。函数

那么咱们彻底能够借助这个思路实现咱们开头的目标。浪尖在这里直接上案例了:oop

import org.apache.spark.{SparkConf, SparkContext, TaskContext}
import org.elasticsearch.hadoop.cfg.ConfigurationOptions

object es2sparkRunJob {

  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setMaster("local[*]").setAppName(this.getClass.getCanonicalName)

    conf.set(ConfigurationOptions.ES_NODES, "127.0.0.1")
    conf.set(ConfigurationOptions.ES_PORT, "9200")
    conf.set(ConfigurationOptions.ES_NODES_WAN_ONLY, "true")
    conf.set(ConfigurationOptions.ES_INDEX_AUTO_CREATE, "true")
    conf.set(ConfigurationOptions.ES_NODES_DISCOVERY, "false")
    conf.set("es.write.rest.error.handlers", "ignoreConflict")
    conf.set("es.write.rest.error.handler.ignoreConflict", "com.jointsky.bigdata.handler.IgnoreConflictsHandler")

    val sc = new SparkContext(conf)
    import org.elasticsearch.spark._

    val rdd = sc.esJsonRDD("posts").repartition(10)

    rdd.count()
    val func = (itr : Iterator[(String,String)]) => {
      var count = 0
      itr.foreach(each=>{
        count += 1
      })
      (TaskContext.getPartitionId(),count)
    }

    val res = sc.runJob(rdd,func)

    res.foreach(println)

    sc.stop()
  }
}

例子中driver端获取的就是每一个task处理的数据量。效率高,并且操做灵活高效~是否是很骚气~~

相关文章
相关标签/搜索