Model call() 和 train_step() 何时被调用?

kir*_*off 9 python machine-learning tensorflow

我正在阅读本教程,了解如何自定义训练循环

https://colab.research.google.com/github/tensorflow/docs/blob/snapshot-keras/site/en/guide/keras/customizing_what_happens_in_fit.ipynb#scrollTo=46832f2077ac

最后一个示例显示了通过自定义训练实现的 GAN,其中仅定义了__init__train_step和方法compile

class GAN(keras.Model):
    def __init__(self, discriminator, generator, latent_dim):
        super(GAN, self).__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim

    def compile(self, d_optimizer, g_optimizer, loss_fn):
        super(GAN, self).compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn

    def train_step(self, real_images):
        if isinstance(real_images, tuple):
            real_images = real_images[0]
        ...
Run Code Online (Sandbox Code Playgroud)

如果我的模型也有自定义函数会怎样call()?是否train_step()覆盖call()call()和不是train_step()都被称为fit()and 两者有什么区别?

下面“我”写了另一段代码,我想知道什么被称为fit(),call()train_step()

class MyModel(tf.keras.Model):
  def __init__(self, vocab_size, embedding_dim, rnn_units):
    super().__init__(self)
    self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
    self.gru = tf.keras.layers.GRU(rnn_units,
                                   return_sequences=True,
                                   return_state=True,
                                   reset_after=True
                                   )
    self.dense = tf.keras.layers.Dense(vocab_size)

  def call(self, inputs, states=None, return_state=False, training=False):
    x = inputs
    x = self.embedding(x, training=training)
    if states is None:
      states = self.gru.get_initial_state(x)
    x, states = self.gru(x, initial_state=states, training=training)
    x = self.dense(x, training=training)

    if return_state:
      return x, states
    else:
      return x

  @tf.function
  def train_step(self, inputs):
    # unpack the data
    inputs, labels = inputs
  
    with tf.GradientTape() as tape:
      predictions = self(inputs, training=True) # forward pass
      # Compute the loss value
      # (the loss function is configured in `compile()`)
      loss=self.compiled_loss(labels, predictions, regularization_losses=self.losses)

    # compute the gradients
    grads=tape.gradient(loss, model.trainable_variables)
    # Update weights
    self.optimizer.apply_gradients(zip(grads, model.trainable_variables))
    # Update metrics (includes the metric that tracks the loss)
    self.compiled_metrics.update_state(labels, predictions)

    # Return a dict mapping metric names to current value
    return {m.name: m.result() for m in self.metrics}
Run Code Online (Sandbox Code Playgroud)

xdu*_*ch0 15

这些是不同的概念,用法如下:

  • train_step被称为fit. 基本上,fit循环数据集并提供每个批次train_step(当然,然后处理指标、簿记等)。
  • call当您调用模型时使用。准确地说,model(inputs)在您的情况下编写或self(inputs)将使用该函数__call__,但该类Model定义了该函数,以便它将依次使用该函数call

这些是技术方面。直观地说:

  • call应该定义模型的前向传播。即输入如何转换为输出。
  • train_step定义训练步骤的逻辑,通常采用梯度下降。它经常被利用,call因为训练步骤往往包括模型的前向传递来计算梯度。

至于您链接的 GAN 教程,我想说它实际上可以被认为是不完整的。它无需定义即可工作,call因为自定义显式train_step调用生成器/鉴别器字段(因为这些是预定义模型,因此可以像往常一样调用它们)。如果您尝试像这样调用 GAN 模型gan(inputs),我会假设您收到一条错误消息(我没有对此进行测试)。gan.generator(inputs)因此,例如,您总是必须调用来生成。

最后(这部分可能有点令人困惑),请注意,您可以对 a 进行子类化Model来定义自定义训练步骤,然后通过功能 API 对其进行初始化(如model = Model(inputs, outputs)),在这种情况下,您可以在训练步骤中使用,而call无需使用你自己定义它,因为函数式 API 会处理这个问题。