为什么在 Keras 子类化 API 中,永远不会调用 call 方法,而是通过调用此类的对象来传递输入?

Atu*_*tul 10 python deep-learning keras tensorflow tf.keras

在使用 Keras 子类化 API 创建模型时,我们编写了一个自定义模型类并定义了一个名为call(self, x)(主要用于编写前向传递)的函数,该函数需要一个输入。但是,此方法永远不会被调用,而是作为call传递给此类的对象,而不是将输入传递给model(images)

model当我们__call__在类中没有实现 Python 特殊方法时,我们如何调用这个对象并传递值

class MyModel(Model):
  def __init__(self):
    super(MyModel, self).__init__()
    self.conv1 = Conv2D(32, 3, activation='relu')
    self.flatten = Flatten()
    self.d1 = Dense(128, activation='relu')
    self.d2 = Dense(10, activation='softmax')

  def call(self, x):
    x = self.conv1(x)
    x = self.flatten(x)
    x = self.d1(x)
    return self.d2(x)

# Create an instance of the model
model = MyModel()
Run Code Online (Sandbox Code Playgroud)

使用 tf.GradientTape 训练模型:

@tf.function
def train_step(images, labels):
  with tf.GradientTape() as tape:
    predictions = model(images)
    loss = loss_object(labels, predictions)
  gradients = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))

  train_loss(loss)
  train_accuracy(labels, predictions)
Run Code Online (Sandbox Code Playgroud)

输入不应该像下面这样传递:

model = MyModel()
model.call(images)
Run Code Online (Sandbox Code Playgroud)

Gee*_*ode 6

其实__call__方法是在Layer类中实现的,由Network类继承,由Model类继承:

class Layer(module.Module):
    def __call__(self, inputs, *args, **kwargs):

class Network(base_layer.Layer):

class Model(network.Network):
Run Code Online (Sandbox Code Playgroud)

所以MyClass会继承这个__call__方法。

附加信息:

所以实际上我们所做的是覆盖继承的call方法,call然后从继承的__call__方法中调用哪个新方法。这就是为什么我们不需要做一个model.call(). 所以当我们调用我们的模型实例时,它的继承__call__方法会自动执行,它调用我们自己的call方法。