pro*_*cer 24 keras tensorflow tf.keras tensorflow2.0 tensorflow2.x
一个官方教程上@tf.function
说:
为了获得最佳性能并使您的模型可在任何地方部署,请使用 tf.function 从您的程序中制作图形。感谢 AutoGraph,数量惊人的 Python 代码仅适用于 tf.function,但仍有一些陷阱需要警惕。
主要的收获和建议是:
- 不要依赖 Python 的副作用,如对象突变或列表追加。
- tf.function 最适合 TensorFlow 操作,而不是 NumPy 操作或 Python 原语。
- 如有疑问,请使用 for x in y 习语。
它只提到了如何实现带@tf.function
注释的函数,而没有提到何时使用它。
关于如何决定我是否至少应该尝试用 注释函数是否有启发tf.function
?似乎没有理由不这样做,除非我懒得去除副作用或更改诸如range()
-> 之类的东西tf.range()
。但如果我愿意这样做......
是否有任何理由不@tf.function
用于所有功能?
pro*_*ast 26
TLDR:这取决于您的职能以及您是在生产中还是在开发中。不要使用tf.function
,如果你希望能够方便地调试功能,或者如果它落在下亲笔签名或限制tf.v1代码的兼容性。我强烈建议观看 Inside TensorFlow 谈论AutoGraph和Functions,而不是 Sessions。
下面我将分解原因,这些原因均来自 Google 在线提供的信息。
通常,tf.function
装饰器会导致将函数编译为执行 TensorFlow 图的可调用对象。这需要:
tf.function
tf.function
修饰使用 AutoGraph如果你想使用 AutoGraph,tf.function
强烈建议使用而不是直接调用 AutoGraph。原因包括:自动控制依赖,某些 API 需要它,更多缓存和异常助手(来源)。
tf.function
tf.function
修饰使用 AutoGraphDetailed information on AutoGraph limitations is available.
tf.function
, but this is subject to change as tf.v1 code is phased out (Source)It is not allowed to create variables more than once, such as v
in the following example:
@tf.function
def f(x):
v = tf.Variable(1)
return tf.add(x, v)
f(tf.constant(2))
# => ValueError: tf.function-decorated function tried to create variables on non-first call.
Run Code Online (Sandbox Code Playgroud)
In the following code, this is mitigated by making sure that self.v
is only created once:
class C(object):
def __init__(self):
self.v = None
@tf.function
def f(self, x):
if self.v is None:
self.v = tf.Variable(1)
return tf.add(x, self.v)
c = C()
print(c.f(tf.constant(2)))
# => tf.Tensor(3, shape=(), dtype=int32)
Run Code Online (Sandbox Code Playgroud)
Changes such as to self.a
in this example can't be hidden, which leads to an error since cross-function analysis is not done (yet) (Source):
class C(object):
def change_state(self):
self.a += 1
@tf.function
def f(self):
self.a = tf.constant(0)
if tf.constant(True):
self.change_state() # Mutation of self.a is hidden
tf.print(self.a)
x = C()
x.f()
# => InaccessibleTensorError: The tensor 'Tensor("add:0", shape=(), dtype=int32)' cannot be accessed here: it is defined in another function or code block. Use return values, explicit Python locals or TensorFlow collections to access it. Defined in: FuncGraph(name=cond_true_5, id=5477800528); accessed from: FuncGraph(name=f, id=5476093776).
Run Code Online (Sandbox Code Playgroud)
Changes in plain sight are no problem:
class C(object):
@tf.function
def f(self):
self.a = tf.constant(0)
if tf.constant(True):
self.a += 1 # Mutation of self.a is in plain sight
tf.print(self.a)
x = C()
x.f()
# => 1
Run Code Online (Sandbox Code Playgroud)
This if statement leads to an error because the value for else needs to be defined for TF control flow:
@tf.function
def f(a, b):
if tf.greater(a, b):
return tf.constant(1)
# If a <= b would return None
x = f(tf.constant(3), tf.constant(2))
# => ValueError: A value must also be returned from the else branch. If a value is returned from one branch of a conditional a value must be returned from all branches.
Run Code Online (Sandbox Code Playgroud)
小智 4
tf.function 在创建和使用计算图时很有用,它们应该在训练和部署中使用,但大多数函数都不需要它。
假设我们正在构建一个特殊的层,它将成为更大模型的一部分。我们不希望在构造该层的函数之上使用 tf.function 装饰器,因为它只是层外观的定义。
另一方面,假设我们要进行预测或使用某些函数继续训练。我们想要装饰器 tf.function 因为我们实际上是使用计算图来获取一些值。
一个很好的例子是构建编码器-解码器模型。不要将装饰器放在创建编码器或解码器或任何层的函数周围,这只是它将做什么的定义。一定要将装饰器放在“训练”或“预测”方法周围,因为这些方法实际上将使用计算图进行计算。
归档时间: |
|
查看次数: |
5362 次 |
最近记录: |