Van*_*nel 5 python keras tensorflow
我正在将 TensorFlow 代码迁移到 Tensorflow 2.1.0。
这是原始代码:
conv = tf.layers.conv2d(inputs, out_channels, kernel_size=3, padding='SAME')
conv = tf.contrib.layers.batch_norm(conv, updates_collections=None, decay=0.99, scale=True, center=True)
conv = tf.nn.relu(conv)
conv = tf.contrib.layers.max_pool2d(conv, 2)
Run Code Online (Sandbox Code Playgroud)
这就是我所做的:
conv1 = Conv2D(out_channels, (3, 3), activation='relu', padding='same', data_format='channels_last', name=name)(inputs)
conv1 = Conv2D(64, (5, 5), activation='relu', padding='same', data_format="channels_last")(conv1)
#conv = tf.contrib.layers.batch_norm(conv, updates_collections=None, decay=0.99, scale=True, center=True)
pool1 = MaxPooling2D(pool_size=(2, 2), data_format="channels_last")(conv1)
Run Code Online (Sandbox Code Playgroud)
我的问题是我不知道该怎么办tf.contrib.layers.batch_norm。
如何迁移tf.contrib.layers.batch_norm到 Tensorflow 2.x?
更新:
使用评论建议,我认为我已经正确迁移:
conv1 = BatchNormalization(momentum=0.99, scale=True, center=True)(conv1)
Run Code Online (Sandbox Code Playgroud)
但我不确定是否decay是这样momentum,我不知道如何updates_collections在BatchNormalization方法中设置。
小智 2
我在使用经过训练的模型进行微调时遇到了这个问题。tf.contrib.layers.batch_norm只是用类似 OP 的替换tf.keras.layers.BatchNormalization确实给了我一个错误,其修复如下所述。
旧代码如下所示:
tf.contrib.layers.batch_norm(
tensor,
scale=True,
center=True,
is_training=self.use_batch_statistics,
trainable=True,
data_format=self._data_format,
updates_collections=None,
)
Run Code Online (Sandbox Code Playgroud)
更新后的工作代码如下所示:
tf.keras.layers.BatchNormalization(
name="BatchNorm",
scale=True,
center=True,
trainable=True,
)(tensor)
Run Code Online (Sandbox Code Playgroud)
我不确定我删除的所有关键字参数是否都会成为问题,但一切似乎都有效。注意name="BatchNorm"论证。这些图层使用不同的命名模式,因此我必须使用该inspect_checkpoint.py工具来查看模型并找到恰好是的图层名称BatchNorm。
| 归档时间: |
|
| 查看次数: |
4986 次 |
| 最近记录: |