Atu*_*tul 10 python deep-learning keras tensorflow tf.keras
在使用 Keras 子类化 API 创建模型时,我们编写了一个自定义模型类并定义了一个名为call(self, x)(主要用于编写前向传递)的函数,该函数需要一个输入。但是,此方法永远不会被调用,而是作为call传递给此类的对象,而不是将输入传递给model(images)。
model当我们__call__在类中没有实现 Python 特殊方法时,我们如何调用这个对象并传递值
class MyModel(Model):
def __init__(self):
super(MyModel, self).__init__()
self.conv1 = Conv2D(32, 3, activation='relu')
self.flatten = Flatten()
self.d1 = Dense(128, activation='relu')
self.d2 = Dense(10, activation='softmax')
def call(self, x):
x = self.conv1(x)
x = self.flatten(x)
x = self.d1(x)
return self.d2(x)
# Create an instance of the model
model = MyModel()
Run Code Online (Sandbox Code Playgroud)
使用 tf.GradientTape 训练模型:
@tf.function
def train_step(images, labels):
with tf.GradientTape() as tape:
predictions = model(images)
loss = loss_object(labels, predictions)
gradients = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(gradients, model.trainable_variables))
train_loss(loss)
train_accuracy(labels, predictions)
Run Code Online (Sandbox Code Playgroud)
输入不应该像下面这样传递:
model = MyModel()
model.call(images)
Run Code Online (Sandbox Code Playgroud)
其实__call__方法是在Layer类中实现的,由Network类继承,由Model类继承:
class Layer(module.Module):
def __call__(self, inputs, *args, **kwargs):
class Network(base_layer.Layer):
class Model(network.Network):
Run Code Online (Sandbox Code Playgroud)
所以MyClass会继承这个__call__方法。
附加信息:
所以实际上我们所做的是覆盖继承的call方法,call然后从继承的__call__方法中调用哪个新方法。这就是为什么我们不需要做一个model.call(). 所以当我们调用我们的模型实例时,它的继承__call__方法会自动执行,它调用我们自己的call方法。
| 归档时间: |
|
| 查看次数: |
2422 次 |
| 最近记录: |