子类化keras层/模型时如何正确使用`@tf.function`?

Ale*_*NON 7 python keras tensorflow tensorflow2.0

我有一个自定义tf.keras.layers.Layer,它只使用 TF 运算符进行某种位解包(将整数转换为布尔值(0 或 1 浮点数))。

class CharUnpack(keras.layers.Layer):

    def __init__(self, name="CharUnpack", *args, **kwargs):
        super(CharUnpack, self).__init__(trainable=False, name=name, *args, **kwargs)
        # Range [7, 6, ..., 0] to bit-shift integers
        self._shifting_range = tf.reshape(
            tf.dtypes.cast(
                tf.range(7, -1, -1, name='shifter_range'),
                tf.uint8,
                name='shifter_cast'),
            (1, 1, 8),
            name='shifter_reshape')
        # Constant value 0b00000001 to use as bitwise and operator
        self._selection_bit = tf.constant(0x01, dtype=tf.uint8, name='and_selection_bit')

    def call(self, inputs):
        return tf.dtypes.cast(
            tf.reshape(
                tf.bitwise.bitwise_and(
                    tf.bitwise.right_shift(
                        tf.expand_dims(inputs, 2),
                        self._shifting_range,
                    ),
                    self._selection_bit,
                ),
                [x if x else -1 for x in self.compute_output_shape(inputs.shape)]
            ),
            tf.float32
        )

    def compute_output_shape(self, input_shape):
        try:
            if len(input_shape) > 1:
                output_shape = tf.TensorShape(tuple(list(input_shape[:-1]) + [input_shape[-1] * 8]))
            else:
                output_shape = tf.TensorShape((input_shape[0] * 8,))
        except TypeError:
            output_shape = input_shape
        return output_shape

    def compute_output_signature(self, input_signature):
        return tf.TensorSpec(self.compute_output_shape(input_signature.shape), tf.float32)
Run Code Online (Sandbox Code Playgroud)

我尝试对这一层进行基准测试以提高时间性能,如本TF 指南中所示。

inputs = tf.zeros([64, 400], dtype=tf.uint8)

eager = CharUnpack()

@tf.function
def fun(x):
    eager(x)

# Warm-up
eager(inputs)
fun(inputs)

print("Function:", timeit.timeit(lambda: fun(inputs), number=100))
print("Eager:", timeit.timeit(lambda: eager(inputs), number=100))
Run Code Online (Sandbox Code Playgroud)
class CharUnpack(keras.layers.Layer):

    def __init__(self, name="CharUnpack", *args, **kwargs):
        super(CharUnpack, self).__init__(trainable=False, name=name, *args, **kwargs)
        # Range [7, 6, ..., 0] to bit-shift integers
        self._shifting_range = tf.reshape(
            tf.dtypes.cast(
                tf.range(7, -1, -1, name='shifter_range'),
                tf.uint8,
                name='shifter_cast'),
            (1, 1, 8),
            name='shifter_reshape')
        # Constant value 0b00000001 to use as bitwise and operator
        self._selection_bit = tf.constant(0x01, dtype=tf.uint8, name='and_selection_bit')

    def call(self, inputs):
        return tf.dtypes.cast(
            tf.reshape(
                tf.bitwise.bitwise_and(
                    tf.bitwise.right_shift(
                        tf.expand_dims(inputs, 2),
                        self._shifting_range,
                    ),
                    self._selection_bit,
                ),
                [x if x else -1 for x in self.compute_output_shape(inputs.shape)]
            ),
            tf.float32
        )

    def compute_output_shape(self, input_shape):
        try:
            if len(input_shape) > 1:
                output_shape = tf.TensorShape(tuple(list(input_shape[:-1]) + [input_shape[-1] * 8]))
            else:
                output_shape = tf.TensorShape((input_shape[0] * 8,))
        except TypeError:
            output_shape = input_shape
        return output_shape

    def compute_output_signature(self, input_signature):
        return tf.TensorSpec(self.compute_output_shape(input_signature.shape), tf.float32)
Run Code Online (Sandbox Code Playgroud)

如您所见,我可以获得 10 倍的加速!!!所以,我在@tf.function我的CharUnpack.call方法中添加了装饰器:

inputs = tf.zeros([64, 400], dtype=tf.uint8)

eager = CharUnpack()

@tf.function
def fun(x):
    eager(x)

# Warm-up
eager(inputs)
fun(inputs)

print("Function:", timeit.timeit(lambda: fun(inputs), number=100))
print("Eager:", timeit.timeit(lambda: eager(inputs), number=100))
Run Code Online (Sandbox Code Playgroud)

现在,我希望 theeagerfun, 调用花费相似的时间,但我没有得到任何改善。

Function: 0.01062483999885444
Eager: 0.12658399900101358
Run Code Online (Sandbox Code Playgroud)

此外,在这个SO 答案的第 2.1 节中指出,模型默认情况下是图形编译的(这应该是逻辑),但情况似乎并非如此......

如何正确使用@tf.function装饰器使我的图层始终进行图形编译?

thu*_*v89 1

tldrfun()不返回任何内容,tensorflowautograph 足够聪明,可以意识到这一点并忽略 中发生的所有 TF 计算fun(),而eager(x) 必须执行函数中发生的事情call()。这就是为什么您的执行时间极短fun()。至少我认为正在发生的事情 - 我不是 AutoGraph 专家,所以如果我有任何错误,其他人可能能够纠正我。

问题调查

在我们深入之前,让我们用 git 来简化一下事情。首先我将您的原始代码修改如下。让我们增加数据的大小,以确保涉及足够的数字处理,并且数据传输和其他开销不会主导分析。

inputs = tf.zeros([8192, 400], dtype=tf.uint8)
Run Code Online (Sandbox Code Playgroud)

其次,我去掉了一些计算,例如compute_output_shape()固定形状。里面还带了一些张量定义call()。这样就call()可以处理端到端计算的变量定义。

class CharUnpack(tf.keras.layers.Layer):

    
    def __init__(self, name="CharUnpack", *args, **kwargs):
        super(CharUnpack, self).__init__(trainable=False, name=name, *args, **kwargs)
        self._shifting_range = None
        self._selection_bit = None

    @tf.function
    def call(self, inputs):

        if not self._shifting_range:
          # Range [7, 6, ..., 0] to bit-shift integers
          self._shifting_range = tf.reshape(
            tf.dtypes.cast(
              tf.range(7, -1, -1, name='shifter_range'),
              tf.uint8,
              name='shifter_cast'
            ),
            (1, 1, 8),
            name='shifter_reshape')
        
        if not self._selection_bit:
          # Constant value 0b00000001 to use as bitwise and operator
          self._selection_bit = tf.constant(0x01, dtype=tf.uint8, name='and_selection_bit')

        return tf.dtypes.cast(
            tf.reshape(
                tf.bitwise.bitwise_and(
                    tf.bitwise.right_shift(
                        tf.expand_dims(inputs, 2),
                        self._shifting_range,
                    ),
                    self._selection_bit,
                ),
                [x if x else -1 for x in self.compute_output_shape(inputs.shape)]
            ),
            tf.float32
        )

    def compute_output_shape(self, input_shape):
        return [8192, 3200]
Run Code Online (Sandbox Code Playgroud)

第三,我设置了number=1timeit 操作,以确保我们一次分析一个调用。这使得它更容易理解。

# The very first call of either approach
print("Eager:", timeit.timeit(lambda: eager(inputs), number=1))
print("Function:", timeit.timeit(lambda: fun(inputs), number=1))
# The second call
print("Eager:", timeit.timeit(lambda: eager(inputs), number=1))
print("Function:", timeit.timeit(lambda: fun(inputs), number=1))
Run Code Online (Sandbox Code Playgroud)

首先我们看一下具体的功能eager()

eager_concrete = eager.call.get_concrete_function(tf.TensorSpec(shape=[None, 400], dtype=tf.uint8))

print(eager_concrete)
Run Code Online (Sandbox Code Playgroud)

这使,

ConcreteFunction call(inputs)
  Args:
    inputs: uint8 Tensor, shape=(None, 400)
  Returns:
    float32 Tensor, shape=(8192, 3200)
Run Code Online (Sandbox Code Playgroud)

我们来看一下具体的功能fun()

fun_concrete = fun.get_concrete_function(tf.TensorSpec(shape=[None, 400], dtype=tf.uint8))

print(fun_concrete)
Run Code Online (Sandbox Code Playgroud)

这使,

ConcreteFunction fun(x)
  Args:
    x: uint8 Tensor, shape=(None, 400)
  Returns:
    NoneTensorSpec()
Run Code Online (Sandbox Code Playgroud)

所以你立刻就会发现它fun()没有返回任何东西,这应该在你的脑海中引起危险信号。更进一步,我们可以看看 AutoGraph 跟踪产生的图表实际上包含了什么。

graph = fun_concrete.graph
for node in graph.as_graph_def().node:
  print(f'{node.input} -> {node.name}')
Run Code Online (Sandbox Code Playgroud)

其输出,

[] -> x
['x'] -> CharUnpack/StatefulPartitionedCall
Run Code Online (Sandbox Code Playgroud)

接下来,如果您对 执行相同操作eager(),您将看到如下列出的所有原始 TF 操作。

[] -> inputs
[] -> StringFormat
['StringFormat'] -> PrintV2
[] -> shifter_range/start
...
['Reshape'] -> Cast
['Cast', '^NoOp'] -> Identity
Run Code Online (Sandbox Code Playgroud)

我们甚至可以查看生成的代码。

print(tf.autograph.to_code(fun.python_function))
Run Code Online (Sandbox Code Playgroud)

这使,

def tf__fun(x):
    with ag__.FunctionScope('fun', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope:
        out = ag__.converted_call(ag__.ld(eager), (ag__.ld(x),), None, fscope)
Run Code Online (Sandbox Code Playgroud)

所以看看代码,它所做的只是生成一个转换后eager的调用x

我不是 AutoGraph 专家,但我想它所做的只是将给定的输入传递xeager.call()并跳过所有计算。所以fun()只是跳过eager.call()函数中所有重要的计算。

我们如何fun()实际进行计算?

只需添加一条return语句即可fun()

@tf.function
def fun(x):
  out = eager(x)
  return out
Run Code Online (Sandbox Code Playgroud)

这使,

Eager: 0.6245606249999582
Function: 0.3163724480000383
Eager: 0.2076279070001874
Function: 0.22467646699988109
Eager: 0.25076841500003866
Function: 0.240701412999897
Run Code Online (Sandbox Code Playgroud)

所以现在我们可以看到两者都eager.call()花费fun()相同的时间。

TF 文档中可以看出,

除了 tf.Variables 之外,tf.function 必须返回其所有输出。

尽管本节强调了问题的另一面,但它可能与这里发生的事情(间接)相关。