kir*_*off 9 python machine-learning tensorflow
我正在阅读本教程,了解如何自定义训练循环
最后一个示例显示了通过自定义训练实现的 GAN,其中仅定义了__init__、train_step和方法compile
class GAN(keras.Model):
def __init__(self, discriminator, generator, latent_dim):
super(GAN, self).__init__()
self.discriminator = discriminator
self.generator = generator
self.latent_dim = latent_dim
def compile(self, d_optimizer, g_optimizer, loss_fn):
super(GAN, self).compile()
self.d_optimizer = d_optimizer
self.g_optimizer = g_optimizer
self.loss_fn = loss_fn
def train_step(self, real_images):
if isinstance(real_images, tuple):
real_images = real_images[0]
...
Run Code Online (Sandbox Code Playgroud)
如果我的模型也有自定义函数会怎样call()?是否train_step()覆盖call()?call()和不是train_step()都被称为fit()and 两者有什么区别?
下面“我”写了另一段代码,我想知道什么被称为fit(),call()或train_step():
class MyModel(tf.keras.Model):
def __init__(self, vocab_size, embedding_dim, rnn_units):
super().__init__(self)
self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
self.gru = tf.keras.layers.GRU(rnn_units,
return_sequences=True,
return_state=True,
reset_after=True
)
self.dense = tf.keras.layers.Dense(vocab_size)
def call(self, inputs, states=None, return_state=False, training=False):
x = inputs
x = self.embedding(x, training=training)
if states is None:
states = self.gru.get_initial_state(x)
x, states = self.gru(x, initial_state=states, training=training)
x = self.dense(x, training=training)
if return_state:
return x, states
else:
return x
@tf.function
def train_step(self, inputs):
# unpack the data
inputs, labels = inputs
with tf.GradientTape() as tape:
predictions = self(inputs, training=True) # forward pass
# Compute the loss value
# (the loss function is configured in `compile()`)
loss=self.compiled_loss(labels, predictions, regularization_losses=self.losses)
# compute the gradients
grads=tape.gradient(loss, model.trainable_variables)
# Update weights
self.optimizer.apply_gradients(zip(grads, model.trainable_variables))
# Update metrics (includes the metric that tracks the loss)
self.compiled_metrics.update_state(labels, predictions)
# Return a dict mapping metric names to current value
return {m.name: m.result() for m in self.metrics}
Run Code Online (Sandbox Code Playgroud)
xdu*_*ch0 15
这些是不同的概念,用法如下:
train_step被称为fit. 基本上,fit循环数据集并提供每个批次train_step(当然,然后处理指标、簿记等)。call当您调用模型时使用。准确地说,model(inputs)在您的情况下编写或self(inputs)将使用该函数__call__,但该类Model定义了该函数,以便它将依次使用该函数call。这些是技术方面。直观地说:
call应该定义模型的前向传播。即输入如何转换为输出。train_step定义训练步骤的逻辑,通常采用梯度下降。它经常被利用,call因为训练步骤往往包括模型的前向传递来计算梯度。至于您链接的 GAN 教程,我想说它实际上可以被认为是不完整的。它无需定义即可工作,call因为自定义显式train_step调用生成器/鉴别器字段(因为这些是预定义模型,因此可以像往常一样调用它们)。如果您尝试像这样调用 GAN 模型gan(inputs),我会假设您收到一条错误消息(我没有对此进行测试)。gan.generator(inputs)因此,例如,您总是必须调用来生成。
最后(这部分可能有点令人困惑),请注意,您可以对 a 进行子类化Model来定义自定义训练步骤,然后通过功能 API 对其进行初始化(如model = Model(inputs, outputs)),在这种情况下,您可以在训练步骤中使用,而call无需使用你自己定义它,因为函数式 API 会处理这个问题。
| 归档时间: |
|
| 查看次数: |
4087 次 |
| 最近记录: |