小编cos*_*yes的帖子

使用 Tensorflow 2 的 Keras Functional API 时传递 `training=true`

在 TF1 中以图形模式运行时,我相信在使用函数式 API 时需要连接training=Truetraining=False通过 feeddicts。在 TF2 中执行此操作的正确方法是什么?

我相信这是在使用时自动处理的tf.keras.Sequential。例如,我不需要training文档中的以下示例中指定:

model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, 3, activation='relu',
                           kernel_regularizer=tf.keras.regularizers.l2(0.02),
                           input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dropout(0.1),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Dense(10, activation='softmax')
])

# Model is the full model w/o custom layers
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

model.fit(train_data, epochs=NUM_EPOCHS)
loss, acc = model.evaluate(test_data)
print("Loss {:0.4f}, Accuracy {:0.4f}".format(loss, acc))
Run Code Online (Sandbox Code Playgroud)

我还可以假设 keras 在使用功能性 api 进行训练时会自动处理这个问题吗?这是相同的模型,使用函数 api 重写:

inputs = tf.keras.Input(shape=((28,28,1)), name="input_image")
hid = tf.keras.layers.Conv2D(32, 3, activation='relu',
                           kernel_regularizer=tf.keras.regularizers.l2(0.02),
                           input_shape=(28, 28, …
Run Code Online (Sandbox Code Playgroud)

python keras tensorflow tf.keras tensorflow2.0

5
推荐指数
2
解决办法
3531
查看次数

标签 统计

keras ×1

python ×1

tensorflow ×1

tensorflow2.0 ×1

tf.keras ×1