即使使用随机种子也无法使用 Tensorflow 重现结果

Jan*_*lly 5 python random-seed deep-learning keras tensorflow

我正在用我生成的数据在 Keras 中训练一个简单的自动编码器。我目前正在 Google Colab 笔记本中运行代码(以防万一可能相关)。为了获得可重复的结果,我目前正在设置如下所示的随机种子,但它似乎并不完全有效:

# Choose random seed value 
seed_value = 0

# Set numpy pseudo-random generator at a fixed value
np.random.seed(seed_value)

# Set tensorflow pseudo-random generator at a fixed value
import tensorflow as tf
tf.random.set_seed(seed_value)
Run Code Online (Sandbox Code Playgroud)

每次初始化模型时,随机种子代码似乎有助于获得相同的初始权重。我可以model.get_weights()在创建模型后看到它的使用(即使我重新启动笔记本并重新运行代码也是如此)。但是,我无法在模型性能方面获得可重复的结果,因为每次训练后模型权重都不同。我假设上面的随机种子代码确保数据在训练期间每次以相同的方式拆分和混洗,即使我没有事先拆分训练/验证数据(我使用validation_split=0.2)或指定shuffle=False在拟合模型时,但也许我做出这个假设是不正确的?此外,我是否需要包含任何其他随机种子以确保可重复的结果?这是我用来构建和训练模型的代码:

def construct_autoencoder(input_dim, encoded_dim):
   # Add input
   input = Input(shape=(input_dim,))

   # Add encoder layer
   encoder = Dense(encoded_dim, activation='relu')(input)

   # Add decoder layer
   # Input contains binary values, hence the sigmoid activation
   decoder = Dense(input_dim, activation='sigmoid')(encoder)
   model = Model(inputs=input, outputs=decoder)

   return model

autoencoder = construct_autoencoder(10, 6)
autoencoder.compile(optimizer='adam', loss='binary_crossentropy')
# print(autoencoder.get_weights()) -> This is the same every time, even with restarting the notebook

autoencoder.fit([data,
                 data, 
                 epochs=20, 
                 validation_split=0.2,
                 batch_size=16,
                 verbose=0)

# print(autoencoder.get_weights()) -> This is different every time, but not sure why?
Run Code Online (Sandbox Code Playgroud)

如果您对为什么我在模型训练期间没有获得可重复的结果有任何想法,请告诉我。我在 Keras 的网站上找到了这个https://keras.io/getting-started/faq/#how-can-i-obtain-reproducible-results-using-keras-during-development,但不确定它是否与此相关(如果是,为什么?)。我知道还有其他问题询问模型训练的可重复性,但我没有找到任何解决这个特定问题的问题。非常感谢!

Rob*_*Rob 0

除了设置 Keras 文章中的种子和建议(它们确实相关)之外,您还需要确保所有 python 模块的版本都与笔记本中的相同。

pip freeze使用命令(在命令行界面中)可以轻松地在本地检查所有模块的版本。可以通过以下方式逐个模块地在笔记本中进行检查:

import tensorflow as tf
print(tf.__version__)
Run Code Online (Sandbox Code Playgroud)