Tensorflow build() 如何从 tf.keras.layers.Layer 工作

Jam*_*mon 10 python keras tensorflow

我想知道是否有人知道该build()函数tf.keras.layers.Layer在幕后的类中是如何工作的。根据文档

当您知道输入张量的形状并可以完成其余的初始化工作时,将调用 build

所以对我来说,这门课的行为似乎与此类似:

class MyDenseLayer:
  def __init__(self, num_outputs):
    self.num_outputs = num_outputs

  def build(self, input_shape):
    self.kernel = self.add_weight("kernel",
                                  shape=[int(input_shape[-1]), self.num_outputs])

  def __call__(self, input):
    self.build(input.shape) ## build is called here when input shape is known
    return tf.matmul(input, self.kernel)
Run Code Online (Sandbox Code Playgroud)

我无法想象build()会被永远调用__call__,但它是唯一传入输入的地方。有谁知道这到底是如何工作的?

Jak*_*kub 6

Layer.build()方法通常用于实例化层的权重。有关示例,请参阅源代码tf.keras.layers.Dense,并注意权重和偏置张量是在该函数中创建的。该Layer.build()方法接受一个input_shape参数,权重和偏差的形状通常取决于输入的形状。

Layer.call()另一方面,该方法实现了层的前向传递。您不想覆盖__call__,因为它是在基类中实现的tf.keras.layers.Layer。在自定义层中,您应该实现call().

Layer.call()不叫Layer.build()。但是,如果层尚未构建(source),Layer().__call__() 调用它,这将设置一个属性以防止再次调用。换句话说,只在第一次被调用时调用。self.built = TrueLayer.build()Layer.__call__()Layer.build()