Model.trainable = False 与 Model.compile()

xxb*_*iao 6 python keras

这些说法正确吗?

  • Model.trainable = False 除非编译发生,否则它本身绝对没有影响(对任何编译的东西)。
  • 如果我取ModelA已编译的两层( ModelA.compile(...)),创建一个跳过模型ModelB=Model(intermediate_layer1, intermediate_layer2)并设置ModelB.trainable=False, ModelB.compile(...),则不会有任何变化ModelA;假设 trainable 没有被触及,ModelA如果只训练 ModelA ,那么所有的权重都会更新 ( ModelA.fit(...))
  • 这仅适用于重量更新,因此重量将被保存/加载而不会出现问题(即使它是错误的重量)。

这一切都始于我尝试训练 GAN 时,在训练生成器时冻结鉴别器并收到此警告:

 UserWarning: Discrepancy between trainable weights and collected trainable weights, did you set `model.trainable` without calling `model.compile` after ?
Run Code Online (Sandbox Code Playgroud)

我调查了这个,发现人们也调查了这个:

https://github.com/keras-team/keras/issues/8585

这是改编自该问题线程的可重现示例:

# making discriminator
d_input = Input(shape=(2,))
d_output = Activation('softmax')(Dense(2)(d_input))
discriminator = Model(inputs=d_input, outputs=d_output)
discriminator.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['categorical_accuracy'])

# making generator
g_input = Input(shape=(2,))
g_output = Activation('relu')(Dense(2)(g_input))
generator = Model(inputs=g_input, outputs=g_output)

# making gan(generator -> discriminator)
discriminator.trainable = False # CHECK THIS OUT!
gan = Model(inputs=g_input, outputs=discriminator(g_output))
gan.compile(loss='categorical_crossentropy', optimizer='adam')

# training
BATCH_SIZE = 3
some_input_data = np.array([[1,2],[3,4],[5,6]])
some_target_data = np.array([[1,1],[2,2],[3,3]])
# update discriminator
generated = generator.predict(some_input_data, verbose=0)
X = np.concatenate((some_target_data, generated), axis=0)
y = [[0,1]]*BATCH_SIZE + [[1,0]]*BATCH_SIZE
d_metrics = discriminator.train_on_batch(X, y)
# update generator
g_metrics = gan.train_on_batch(some_input_data, [[0,1]]*BATCH_SIZE)
# loop these operations for batches...
Run Code Online (Sandbox Code Playgroud)

当有人说这是错误警告而有人说权重可能会搞砸时,我感到很困惑。

然后我读到这个问题:不应该在模型下 model.trainable=False 冻结权重吗?

这篇文章很好地解释了“可训练”的实际作用。我想知道我的理解是否正确,并确保我的 GAN 训练正确。