AutoGraph是TF提供的一个很是具备前景的工具, 它可以将一部分python语法的代码转译成高效的图表示代码. 因为从TF 2.0开始, TF将会默认使用动态图(eager execution), 所以利用AutoGraph, 在理想状况下, 能让咱们实现用动态图写(方便, 灵活), 用静态图跑(高效, 稳定).html
可是! 在使用的过程当中, 如无心外确定是会有意外的, 这篇文章就是指出一些AutoGraph和tf.function的奇怪的行为, 让你更愉快地使用它们.python
本文假设读者具备必定的Python和TensorFlow的使用经验.git
对tf1.X有经验的读者应该不会对让咱们又爱又恨的计算图(tf.Graph
)和执行会话(tf.Session
)感到陌生, 一个常规的流程以下:github
y=tf.matmul(a, x) + b
)tf.Session
tf.Session
tf.Session.run
来执行计算图的节点, 被执行的节点会反向追踪全部依赖的须要执行的节点并执行计算.如下是上述过程的一个代码例子:apache
g = tf.Graph() #初始化计算图 with g.as_default(): # 设置为默认计算图 a = tf.constant([[10,10],[11.,1.]]) x = tf.constant([[1.,0.],[0.,1.]]) b = tf.Variable(12.) y = tf.matmul(a, x) + b # 描述计算图 init_op = tf.global_variables_initializer() # 待执行节点 with tf.Session() as sess: # 配置会话 sess.run(init_op) # 执行节点 print(sess.run(y)) # 输出结果
在TF 2.0中, 因为默认为动态图, 计算会直接被执行, 也就是说, 咱们不须要缓存
tf.control_dependencies
来声明节点的非直接依赖咱们能够像写普通python代码(or pytorch)同样, 写了就执行:session
a = tf.constant([[10,10],[11.,1.]]) x = tf.constant([[1.,0.],[0.,1.]]) b = tf.Variable(12.) y = tf.matmul(a, x) + b print(y.numpy())
通常来讲, eager代码会比执行相同操做的静态图代码的效率低, 由于不少计算图优化的方法只能用在数据流图上.函数
若是想在TF 2.0上构建传统的计算图, 咱们就须要用到tf.function
.工具
TF 2.0的其中一个重要改变就是去除tf.Session
(此处应有掌声). 这个改变会迫使用户用更好的方式来组织代码: 不用再用让人纠结的tf.Session
来执行代码, 就是一个个python函数, 加上一个简单的装饰器.学习
在TF 2.0里面, 若是须要构建计算图, 咱们只须要给python函数加上@tf.function
的装饰器.
上文提到静态图的执行效率更高, 可是加速并非必定的. 通常来讲, 计算图越复杂, 加速效果越明显. 对于复杂的计算图, 好比训练深度学习模型, 得到的加速是巨大的. (译者注: 我的感受仍是要结合实际来看, 若是某一部分的计算既有复杂的计算图, 而计算图的复杂性又带来了额外的 内存消耗
或者计算量, 那么加速会比较明显, 可是不少时候, 好比通常的CNN模型, 主要计算量并不在于图的复杂性, 而在于卷积、矩阵乘法等操做, 加速并不会很明显. 此处想法有待验证)
这个自动将python代码转成图表示代码的工具就叫作AutoGraph.
在TF 2.0中, 若是一个函数被@tf.function
装饰了, 那么AutoGraph将会被自动调用, 从而将python函数转换成可执行的图表示.
在第一次调用被@tf.function
装饰的函数时, 下列事情将会发生:
tf.
API只会定义一个生成tf.Tensor
输出的节点while
→tf.while
,for
→tf.while
,if
→tf.cond
,assert
→tf.assert
...)tf.control_dependencies
,以便在执行第i+1
行时确保第i
行已经被执行. 至此计算图已经肯定map [id] = graph
下一节将会具体阐述如何将TF 1.X代码块分别改写到eager和计算图版本.
要使用tf.function
, 第一步须要先将TF 1.X的设计计算图的代码放进python函数里面.
def f(): a = tf.constant([[10,10],[11.,1.]]) x = tf.constant([[1.,0.],[0.,1.]]) b = tf.Variable(12.) y = tf.matmul(a, x) + b return y
应为TF 2.0默认是eager的, 咱们能够直接执行该函数(不须要tf.Session
):
print(f().numpy())
咱们就会获得输出:
[[22. 22.] [23. 13.]]
咱们能够直接用@tf.function
来装饰函数f
, 咱们在原来f
的基础上加上宇宙第一的debug大法: print
来更好地看看究竟发生了什么.
@tf.function def f(): a = tf.constant([[10,10],[11.,1.]]) x = tf.constant([[1.,0.],[0.,1.]]) b = tf.Variable(12.) y = tf.matmul(a, x) + b print("PRINT: ", y) tf.print("TF-PRINT: ", y) return y f()
因此发生了什么呢?
@tf.function
将函数f
包进了tensorflow.python.eager.def_function.Function
这个对象, 函数f
被赋予到了这个对象的.python_function
属性.f()
被执行的时候, 计算图会同时被构建, 可是计算不会执行, 所以咱们会获得如下结果, tf.
的操做不会被执行:PRINT: Tensor("add:0", shape=(2, 2), dtype=float32)
ValueError: tf.function-decorated function tried to create variables on non-first call.
在 RFC: Functions, not Session里面有个很是明确的指示:
State (liketf.Variable
objects) are only created the first time the function f is called. 状态(好比tf.Variable
) 只会在函数被第一次调用时建立.
可是 Alexandre Passos指出, 在函数转换成图表示时, 咱们没有办法肯定tf.function
调用了多少次函数, 所以咱们在第一次调用函数f
时, 在图构建的过程当中, 可能会被执行了屡次, 这就致使了上述错误.
形成这个错误的根源在于一样的命令在动态图和静态图中的不一致性. 在动态图中, tf.Variable
时一个普通的python变量, 超出了其做用域范围就会被销毁. 而在静态图中, tf.Variable
则是计算图中一个持续存在的节点, 不受python的做用域的影响. 所以, 这是使用tf.function
的第一个教训:
将一个在动态图中可行的函数转换成静态图须要用静态图的方式思考该函数是否可行
那么咱们能够怎样去规避这个错误呢?
tf.Variable
做为函数的参数传入tf.Variable
tf.Variable
做为类属性来调用这里指方法2和方法3. 显然的, 咱们推荐使用方法3:
class F(): def __init__(self): self._b = None @tf.function def __call__(self): a = tf.constant([[10, 10], [11., 1.]]) x = tf.constant([[1., 0.], [0., 1.]]) if self._b is None: self._b = tf.Variable(12.) y = tf.matmul(a, x) + self._b print("PRINT: ", y) tf.print("TF-PRINT: ", y) return y f = F() f()
咱们以后会看到, 咱们并不能随意地用tf.function
来转化eager的代码并达到加速的目的, 咱们须要想象一下转化是怎么完成的, 在转python的代码到图操做的时候究竟发生了什么, 这些转化包含了什么黑魔法. 这里的例子比较简单, 咱们会在接下来的文章中更深刻的探讨.
@tf.function def f(b): a = tf.constant([[10,10],[11.,1.]]) x = tf.constant([[1.,0.],[0.,1.]]) y = tf.matmul(a, x) + b print("PRINT: ", y) tf.print("TF-PRINT: ", y) return y b = tf.Variable(12.) f(b)
上述函数会获得咱们想要的结果, 另外, 做为参数被传入的变量可以在函数中直接更新, 而更新后的值会在函数外也适用. 下面的代码会打印出1,2,3
a = tf.Variable(0) @tf.function def g(x): x.assign_add(1) return x print(g(a)) print(g(a)) print(g(a))
@tf.function
装饰器来将python代码转成图表示代码tf.Variable
在以后的部分咱们会更加深刻地探讨输入参数类型对效率的影响, 以及python操做的转换细节.
声明: 本文翻译自Paolo Galeone的博客, 已取得做者的赞成, 如需转载本文请联系本人
Disclaimer: This is a translation of the article Analyzing tf.function to discover AutoGraph strengths and subtleties by Paolo Galeone.