如何将keras中的参数设置为不可训练?

Swi*_*son 3 python deep-learning keras

我是Keras的新手,正在建立模型。我想在训练前几层时冻结模型最后几层的权重。我试图将横向模型的可训练属性设置为False,但似乎不起作用。这是代码和模型摘要:

opt = optimizers.Adam(1e-3)
domain_layers = self._build_domain_regressor()
domain_layers.trainble = False
feature_extrator = self._build_common()
img_inputs = Input(shape=(160, 160, 3))
conv_out = feature_extrator(img_inputs)
domain_label = domain_layers(conv_out)
self.domain_regressor = Model(img_inputs, domain_label)
self.domain_regressor.compile(optimizer = opt, loss='binary_crossentropy', metrics=['accuracy'])
self.domain_regressor.summary()
Run Code Online (Sandbox Code Playgroud)

模型摘要:模型摘要

如您所见,model_1是可训练的。但是根据代码,它被设置为不可训练。

Swi*_*son 8

单词“trainble”中有一个拼写错误(缺少一个“a”)。可悲的是,keras 没有警告我该模型没有“trainble”属性。这个问题可以结束了。


grw*_*wlf 8

尽管原始问题的解决方案是一个错字修复,但让我添加一些关于 keras 可训练的信息。

现代 Keras 包含以下用于查看和操作可训练状态的工具:

  • tf.keras.Layer._get_trainable_state()函数 - 打印字典,其中键是模型组件,值是布尔值。请注意,这tf.keras.Model也是一个tf.Keras.Layer.
  • tf.keras.Layer.trainable 属性 - 操纵各个层的可训练状态。

因此,典型的操作如下所示:

# Print current trainable map:
print(model._get_trainable_state())

# Set every layer to be non-trainable:
for k,v in model._get_trainable_state().items():
    k.trainable = False

# Don't forget to re-compile the model
model.compile(...)
Run Code Online (Sandbox Code Playgroud)


Gee*_*ode 6

您可以简单地将一个布尔值分配给layer属性trainable

model.layers[n].trainable = False
Run Code Online (Sandbox Code Playgroud)

您可以可视化可训练的图层:

for l in model.layers:
    print(l.name, l.trainable)
Run Code Online (Sandbox Code Playgroud)

您也可以通过模型定义传递它:

frozen_layer = Dense(32, trainable=False)
Run Code Online (Sandbox Code Playgroud)

从Keras 文档中

To "freeze" a layer means to exclude it from training, i.e. its weights will never be updated. This is useful in the context of fine-tuning a model, or using fixed embeddings for a text input.
You can pass a trainable argument (boolean) to a layer constructor to set a layer to be non-trainable. Additionally, you can set the trainable property of a layer to True or False after instantiation. For this to take effect, you will need to call compile() on your model after modifying the trainable property.

  • 请注意,您不能简单地在“层”上设置 `trainable=False`,您必须从您将编译的“模型实例”中获取层:`self.domain_regressor`。您可能需要递归层查找,因为模型中有模型。 (4认同)
  • 它不会真的被冻结,至少在我的经验中(也许新版本的工作方式不同)。如果您没有从您正在编译的“确切模型”中获得图层,则可训练属性将无法真正起作用。 (2认同)