Control Flow in Tensorflow TF中的控制流解析

#写在前面 本文翻译自Tensorflow团队的文章Tensorflow Control Flow Implementation,部份内容加入了笔者本身的理解,若有不妥之处还望各位指教。异步

目录

  • 概览
  • 控制流核心概念
  • 控制流结构的编译
  • 条件表达式
  • while循环
  • 实现
  • 分布式条件表达式
  • 分布式while循环
  • 自动微分

概览

本文将会介绍当前在Tensorflow中控制流操做的设计和实现。这是一篇基于原始设计的描述性文档,设计的细节还请参考源代码。分布式

本文将要讲述的内容是:函数

  • 介绍Tensorflow为了处理控制流加入的5个核心的操做;
  • 展现高层的控制流是如何经过5个基础操做融入数据流图的;
  • 解释加入了控制流的数据流图是怎样被Tensorflow运行时执行的,包括融合了多种设备(CPU,GPU,TPU)的分布式执行方式;
  • 描述了对控制流结构如何自动求导;

控制流核心概念

Tensorflow中控制流的基础设计理念是,经过引入少许的简单基础操做,为多样的Tensorflow应用提供丰富的控制流表达。咱们指望这些操做灵活且富有表现力,可以做为高层的领域专用语言(DSL,Domain Specific Language)的编译目标。它们须要很方便的嵌入Tensorflow目前的数据流模型中,而且能够方便的进行并行的、分布式的执行以及自动求导。本节将介绍这5种控制流相关的基本操做。它们与Dennis和Arvind在数据流机(dataflow machines)中引入的控制流操做很像。使用Switch和Merge可使咱们事先条件控制,将这5种基础操做组合起来,可使咱们实现while循环。oop

图1

在Tensorflow中,每个op都会在一个执行帧(execution frame)中被执行,控制流操做负责建立和管理这些执行帧。好比,对于while循环,Tensorflow的运行时会建立一个执行帧,而后将全部属于该while循环的操做放在这个执行帧中执行。不一样执行帧中的操做能够并行执行,只要它们之间没有数据依赖。url

Switch:一个Switch操做根据控制输入p的布尔值,将一个输入张量d推动到某一个输出(二选一)。只有到Switch操做的两个输入都准备好以后,它才会执行。spa

MergeMerge操做将它的其中一个输入推向输出。当一个Merge操做的任意一个输入准备好以后,Merge操做就会执行。在多个输入都准备好的状况下,Merge操做的输出不肯定。.net

Enter(name)Enter操做将它的输入推向名为name的执行帧。Enter操做其实是把一个执行帧的张量推向它的子执行帧。同一个子执行帧上可能会有多个Enter操做,它们将不一样的张量推向子执行帧。当输入准备好以后,Enter操做就会执行。一个新的执行帧在它的第一个Enter操做执行以后开始执行。翻译

ExitExit操做,将一个张量从一个子执行帧推向它的父执行帧。它的做用是将张量从子执行帧返回给父执行帧。一个子执行帧可能有多个Exit操做指向父执行帧,每一个操做都会异步的将一个张量返回给父执行帧。当它的输入准备好以后,Exit操做开始执行。设计

NextIterationNextIteration操做将一个张量从当前执行帧的一轮迭代传递到下一轮迭代。Tensorflow的运行时在执行帧内部保存了一个迭代轮数。任何一个在执行帧中执行的操做都有惟一的一个迭代轮数的属性,它能够帮助咱们分辨一个迭代运算中不一样的执行轮次。注意在一个执行帧中可能会有多个NextIteration操做。当执行帧的第N轮执行的第一个NextIteration操做开始执行时,Tensorflow的运行时开始执行第N+1轮的迭代。当更多的张量经过了NextIteration操做进入新的执行轮次时,新执行轮次中更多的操做就会开始运行。当输入准备完成以后,NextIteration操做开始执行。code

控制流结构的编译

有了这5种基础的操做,高级的程序部件,例如条件表达式和whiile循环就能够被编译进入数据流图,而后被Tensorflow的运行时执行。下面咱们来看一下条件表达式和while循环是如何在Tensorflow内部实现的。

条件表达式

如下是构建条件表达式cond(pred, fn1, fn2)的数据流图的高层伪代码。为了简化,咱们忽略了实际使用中的细节,读者能够在control_flow_ops.py中找到实现细节:

//构建true分支图
context_t = CondContext(pred, branch=1)
res_t = context_t.Call(fn1)

//构建false分支图
context_t = CondContext(pred, branch=0)
res_f = context_f.Call(fn2)

//为输出添加Merge节点
merges = [Merge([f,t]) for (f,t) in zip(res_f, res_t)]
return merges

对于条件表达式的每个分支,咱们建立了一个新的控制流上下文,而且在上下文中调用了图构建的函数(fn1或者fn2)。条件上下文容许咱们获取任意的外部张量(不在上下文中建立的),而且插入一个合适的Switch操做来保证它会进入一个分支。这就保证了,只有当这个分支被选择时,它对应的操做才会被执行。因为Tensorflow是异步执行的,外部的张量可能在不一样的时间到达,所以咱们为每个外部张量准备了一个Switch操做来最大化并行度。

每一个分支都返回了张量的列表(res_t或者res_f),所以咱们又添加了一个Merge操做来对结果进行合并,这样只要任何一个分支执行成功了,就能获得输出(前面讲到,对于Merge操做,只要其中一个输入准备好了,就会产生输出)。

让咱们来看一个简单的例子:

图2

tf.cond(x<y, lambda: tf.add(x,z), lambda: tf.square(y))

在生成的数据流图中,Switch操做的插入是为了控制x,y,z张量的流动。在true/false分支,只有Switch操做的true/false的输出才会被使用。因为Add操做的输入来自Switch操做的true分支,所以只有x小于y时,Add操做才会被执行。一样的,只有x大于等于y时,Square操做才会被执行。最终Merge操做发送Add或者Square的结果。若是条件表达式有多个结果,那么将会有多个Merge操做,每一个结果对应一个Merge操做。

固然,利用Switch和Merge操做实现条件表达式还有不少方法,咱们选择当前的实现,主要是由于它更容易进行自动求导。

while循环

如下是构建数据流图中while循环的高层伪代码:

while_context = WhileContext()
while_context.Enter()

//为每个循环变量添加Enter节点
enter_vars = [Enter(x, frame_name) for x in loop_vars]

//添加Merge节点,注意input[1]将会在后面被迭代
merge_vars = [Merge([x,x]) for x in enter_vars]

//构建循环条件子图
pred_result = pred(*merge_vars)

//添加Switch节点
switch_vars = [Switch(x, pred_result) for x in merge_vars]

//构建循环体子图
body_result = body(*[x[1] for x in switch_vars])

//添加NextIteration节点
next_vars = [NextIteration(x) for x in body_result]

//构建循环
for m,v in zip(merge_vars, next_vars):
    m.op._update_input(1,v)

//添加Exit节点
exit_vars = [Exit(x[0]) for x in switch_vars]
while_context.Exit()
return exit_vars

整个while循环图建立在while循环的控制流上下文中。整个思路比较简单。

从循环变量开始,咱们为它们分别添加一个Enter操做和一个Merge操做。咱们使用它们的结果(merge_vars)来构建判断子图,从而计算循环终止条件。

在添加了Switch操做以后,咱们使用Switch操做的true分支来构建循环体子图。循环体的结果须要进入下一轮迭代,所以咱们添加了一个NextIteration操做,而且将其输出指向Merge操做的第二个输入,这样就造成了循环,容许咱们在执行图是不断的运行一样的一组操做。

Switch操做的false输出是整个while循环的输出,所以咱们在它后面加入了Exit操做,来返回运算结果。与条件表达式相似,while循环的上下文被用来追踪在pred和lambda中使用的外部张量。这些外部张量被看作是循环常数,咱们自动为每个外部张量插入了一个Enter操做,使它在while循环的上下文内部可以被访问。嵌套的循环须要添加嵌套的Enter操做。

一样的,让咱们看一个简单的例子:

图3

tf.while_loop(lambda i:i<10, lambda i: tf.add(i,1),[0])

如上图所示,咱们只有一个循环变量。若是有多个循环变量,咱们须要添加多个Enter,Merge,Switch,NextIteration和Exit操做。这使得跨循环和跨迭代轮次的执行成为可能。你可能注意到咱们省略了常量的表示方法,若是你想要理解更深层次的细节,请查看源代码。

这种对于条件表达式和while循环的支持,使得咱们能够表达任意嵌套的条件和循环。例如,一个循环体内可能嵌套着另一个循环体。TF保证每一个循环被赋予了一个惟一的帧名称。

实现

Tensorflow的运行时负责对数据流图进行执行。下面咱们先来对此作一个快速的概览。

为了在多台设备上运行,TF自动将计算操做分配到不一样的设备上。基于设备分配,TF自动的将数据流图划分红子图,每台设备有一个子图对应。当数据流图的一条边被图分割切段时(边两侧的节点分配在两台不一样的设备上),咱们自动的插入一对send和recv节点,以便在设备间传输数据。一对send和recv节点经过一个惟一的键实现通讯,recv节点主动的从send节点拉取数据。例如,如下就是将原图分割到两台设备后的结果。TF对于分割没有添加任何限制,只要一个节点可以在一台设备上进行运算,就能够被分配到这台设备。

图4

若是一个子图被分配到一个设备上运行,那么这个设备将会使用隶属于它的执行器来执行这个子图。执行器从source节点开始,依次执行已经准备好的节点。除了Merge节点以外,对于任何一个其余节点来讲,只要它的输入准备好了,这个节点就能够开始执行了。注意一张子图中全部的recv节点都被认为是source节点。

若是没有控制流,图执行的过程会很是的直接:每一个节点仅被执行一次,而且当全部节点都执行结束以后,整个图的执行就完成了。控制流的引入带来了必定的复杂性。有了控制流,一个节点可能被执行任意次(甚至包括0次)。执行器须要管理对于同一个节点的多个同时存在的执行实例,而且决定计算图合适执行结束。

为了追踪计算中产生的张量,执行器中的张量被使用一个形如(value, is_dead, tag)的元组来标识,value是张量值,is_dead是一个布尔值,用来标识这个张量是否在一个未执行的条件分支上,tag是这个张量的惟一标识(产生张量的节点的执行实例)。本质上,tag定义了执行的上下文,在同一个执行上下文下,一个操做最多被执行一次。tag是send/recv之间通讯的键的一部分,用来辨识同一对send/recv节点的不一样执行。

执行器遵循了以下的执行规则(注意,某个节点的全部输入都必须包含一样的tag)

Switch(p,d) = (r1,r2)
r1 = (value(d), p || is_dead(d),tag(d))
r2 = (value(d), !p || is_dead(d),tag(d))

Merge(d1,d2) = r
r = if is_dead(d1) then d2 else d1

Enter(d, frame_name) = r
value(r) = value(d)
is_dead(r) = is_dead(d)
tag(r) = tag(d)/frame_name/0

Exit(d) = r
value(r) = value(d)
is_dead(r) = is_dead(d)
tag(r) = tag1 where tag(d)=tag1/frame_name/n

NextIteration(d) = d1
value(d1) = value(d)
is_dead(d1) = is_dead(d)
tag(d1) = tag1/frame_name/(n+1) where tag(d) = tag1/frame_name/n

Op(d1,...,dm) = (r1,...,rn)
value(ri) = Op.Compute(value(d1),...,value(dm)) if !is_dead(ri)
is_dead(ri) = any(is_dead(d1),...,is_dead(dm)), for all i
tag(ri) = tag(d1), for all i

最后一个规则适用于全部的非控制流节点。注意只有当全部的输入都有效时,计算才会执行。若是有一个dead输入,咱们将会跳过计算,而将dead信号传递下去。对于dead信号的传递将有助于支持控制流的分布式执行。

分布式条件表达式

对于分布式执行来讲,一个条件表达式可能被分配到了不一样的设备上,以下图所示:

图5

因为每个recv节点都是source节点,而且随时可能会开始执行,在设备B上的recv节点甚至在出于未选择的条件分支上时也会执行。为了让出于未选择的分支上的recv节点的执行合理化,咱们将is_dead标签经过send节点跨设备传输到recv节点。这种信息会一直跨越设备传输下去。这种简单的传输机制使得在分布式环境下的条件判断更加天然,也有助于分布式环境下的while循环。

分布式的while循环

在分布式环境下,一个while循环(特别是循环体),可能被分割到不一样的设备上。若是咱们简单的应用分割逻辑,而后在跨设备的节点之间插入send/recv,那么设备上的局部执行器将缺乏准确执行while循环的信息。

图6

让咱们经过一个例子来认识这个问题。在上述例子中,Op在循环体中,而且被分配给了设备B。一个简单的分割可能会在Switch和Op之间插入一对send/recv节点来执行跨设备的数据传输。然而,这样是没法工做的,由于设备B并不知道recv和Op操做是处在一个循环当中的,在执行完Op一次以后,设备B上的执行器就会认为,它的工做已经完成了(从设备B的角度看,它只须要从recv获取数据,执行Op,而后将结果经过send发送出去,执行就结束了)。解决方案是,重写数据流图,在while循环体分配到的每一个设备上,添加一个控制循环状态机(以下图中所示)。标量0被用来做为Enter节点的输入。

图7

这些控制循环为设备上的执行器提供了足够的信息,使得它们能够像之前同样独立的执行,同时经过send/recv与其它设备通讯。注意到图中的虚线表明了控制输入。

(具体执行过程分为0次执行,和大于等于1次执行两种状况讨论,这里就不写了,你们能够自行分析)

注意到执行中有很是多的并行执行。例如,在接收到P以后,设备B能够开始下一轮迭代,或者中止执行。一个设备可能同时存在并行的多个执行轮次,而且两个不一样的设备还能够同时处在同一个循环的不一样迭代轮次上。

这种while循环的分布式执行方式带来的开销是,任何一个参与的设备都必须在每个迭代轮次里,接收来自产生P的设备传递过来的布尔张量。因为执行过程是高度并行的,这种开销能够忽略不计了。

下图展现了当一个while循环被分割到不一样的设备上时是什么样子。每一个分割的部分都被添加了一个控制循环结构,用来控制while循环内部的recv操做。重写以后的新图与原图是语义等价的。

图8

对于嵌套的while循环,咱们按照下图所示的方式将控制循环堆叠起来。注意若是一台设备仅包含了外层循环的节点,咱们不会在它上面添加与内层循环有关的控制循环结构。

图9

自动微分

待补充。

相关文章
相关标签/搜索