在计算时,为了节省内存,不把全部的数据一次所有加载到内存中,有一种设计模式叫迭代器模式。设计模式
迭代器模式:在逻辑代码执行时,真正的逻辑并未执行,而是建立了新的迭代器,新的迭代器保存着对当前迭代器的引用从而造成链表,每一个迭代器须要实现hasNext(),next()两个方法。当触发计算时,最后一个建立的迭代器会调用next方法,next方法会调用父迭代器的next方法。ide
例如:函数
val list = List("a a", "b d", "c e") val it = list.iterator it.flatMap(_.split(" ")).map((_, 1)).filter(_._1 != "").foreach(println)
这个例子中it是初始迭代器,后面每一个方法都会生成一个新的迭代器,但并不进行迭代计算,到最后foreach方法(相似action算子),开始执行迭代计算了。this
咱们依次展开:spa
def flatMap[B](f: A => GenTraversableOnce[B]): Iterator[B] = new AbstractIterator[B] { // f做用在上游单条数据的结果转换成的iterator private var cur: Iterator[B] = empty private def nextCur() { cur = f(self.next()).toIterator } def hasNext: Boolean = { while (!cur.hasNext) { if (!self.hasNext) return false nextCur() } true } def next(): B = (if (hasNext) cur else empty).next() }
flatMap方法是建立了一个AbstractIterator的匿名内部类,并实现了hasNext和next两个方法。每当调用next时,会先调用hasNext,在hasNext中,调上游的iterator的next方法获取上游这条数据的返回结果,再对这条结果执行用户传入的函数f并返回结果后,将其转换为iterator,再返回这个iterator的next的结果。线程
def map[B](f: A => B): Iterator[B] = new AbstractIterator[B] { def hasNext = self.hasNext def next() = f(self.next()) }
map与flatMap的代码模板同样,逻辑更简单,只是对上游的next返回结果执行用户传入的函数,再返回。设计
def filter(p: A => Boolean): Iterator[A] = new AbstractIterator[A] { private var hd: A = _ private var hdDefined: Boolean = false def hasNext: Boolean = hdDefined || { do { if (!self.hasNext) return false hd = self.next() } while (!p(hd)) hdDefined = true true } def next() = if (hasNext) { hdDefined = false; hd } else empty.next() }
filter中,调用hasNext时,先调用上游iterator的hasNext,若是返回false,那么直接返回false。若是上游的hasNext返回true,就取出上游的next结果,并将用户传入的判断函数p做用在这个结果上,若为true,则退出循环,并将hdDefine置为true;若p的结果为false,则继续从上游取下一条数据让p判断。code
def foreach[U](f: A => U) { while (hasNext) f(next()) }
遍历迭代器,将每一个元素传给用户传入的函数f中执行。继承
在spark的每一个任务中,都是以迭代器模式进行计算的。而每一个迭代器的链表对应每一个分区中的数据。RDD的每一个算子会生成一个新的RDD,新的RDD会保存对前一个RDD的引用,而且会保存传入到算子中的用户定义函数。ip
例如:
def map[U: ClassTag](f: T => U): RDD[U] = withScope { val cleanF = sc.clean(f) new MapPartitionsRDD[U, T](this, (context, pid, iter) => iter.map(cleanF)) }
这个map算子会返回一个MapPartitionsRDD,MapPartitionsRDD中含有当前this这个RDD的引用,并把用户定义函数f转换成做用于iterator的函数传入到MapPartitionsRDD中。
RDD中有个抽象方法compute,MapPartitionsRDD中实现以下:
override def compute(split: Partition, context: TaskContext): Iterator[U] = f(context, split.index, firstParent[T].iterator(split, context))
从父RDD(firstParent[T])获取迭代器,这个过程须要分区信息split和任务上下文。再map算子中转换后的用户定义函数做用在这个迭代器上。
compute方法同迭代器模式相似,也是不断从上游RDD获取的迭代器,这样来得到一个迭代器的链表,这个链表就是一个task要执行的任务。
为了说明这个过程,咱们从Executor源码来找寻。
def launchTask(context: ExecutorBackend, taskDescription: TaskDescription): Unit = { val tr = new TaskRunner(context, taskDescription) runningTasks.put(taskDescription.taskId, tr) threadPool.execute(tr) }
Executor源码中有个launchTask方法,会建立TaskRunner,将TaskRunner交给线程池执行。TaskRunner是什么呢?
在Executor源码中有一个内部类,TaskRunner,它是一个线程的任务:
class TaskRunner( execBackend: ExecutorBackend, private val taskDescription: TaskDescription) extends Runnable {
继承Runnable必须实现run方法,找到run方法,在run方法中找到了以下代码:
val res = task.run( taskAttemptId = taskId, attemptNumber = taskDescription.attemptNumber, metricsSystem = env.metricsSystem) threwException = false
点进这里task的run,会在Task类中找到runTask(context),这个runTask是Task类的抽象方法,会被Task的子类实现。好比ResultTask,这个子类是最后collect类型的action算子出发的任务类。在ResultTask中,runTask方法调用了rdd的iterator方法来获取iterator,并将用户定义的方法做用到这个iterator上。
override def runTask(context: TaskContext): U = { // Deserialize the RDD and the func using the broadcast variables. val threadMXBean = ManagementFactory.getThreadMXBean val deserializeStartTime = System.currentTimeMillis() val deserializeStartCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) { threadMXBean.getCurrentThreadCpuTime } else 0L val ser = SparkEnv.get.closureSerializer.newInstance() val (rdd, func) = ser.deserialize[(RDD[T], (TaskContext, Iterator[T]) => U)]( ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader) _executorDeserializeTime = System.currentTimeMillis() - deserializeStartTime _executorDeserializeCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) { threadMXBean.getCurrentThreadCpuTime - deserializeStartCpuTime } else 0L func(context, rdd.iterator(partition, context)) }
这个rdd的iterator方法会获取父rdd的迭代器或调用compute方法。
final def iterator(split: Partition, context: TaskContext): Iterator[T] = { if (storageLevel != StorageLevel.NONE) { getOrCompute(split, context) } else { computeOrReadCheckpoint(split, context) } } private[spark] def computeOrReadCheckpoint(split: Partition, context: TaskContext): Iterator[T] = { if (isCheckpointedAndMaterialized) { firstParent[T].iterator(split, context) } else { compute(split, context) } }
spark每一个任务都是由向前依赖串联起来RDD链表生成的iterator链表构成的,任务执行由最后的一个iterator的迭代开始,调用上游的迭代器的next,直到迭代到第一个iterator。这样避免了将全部数据先加载到内存中,而每次计算都只从源头取一条数据,大大节省了内存。