在TensorFlow中保存特定权重

luo*_*hao 5 tensorflow

在我的神经网络中,我创建了一些tf.Variable对象如下:

weights = {
    'wc1_0': tf.Variable(tf.random_normal([5, 5, 3, 64])),
    'wc1_1': tf.Variable(tf.random_normal([5, 5, 3, 64]))
}
biases = {
    'bc1_0': tf.Variable(tf.constant(0.0, shape=[64])),
    'bc1_1': tf.Variable(tf.constant(0.0, shape=[64]))
}
Run Code Online (Sandbox Code Playgroud)

我将如何保存变量weightsbiases不保存其它变量的具体数量的迭代之后?

mrr*_*rry 7

在TensorFlow中保存变量的标准方法是使用tf.train.Saver对象.默认情况下,它会保存问题中的所有变量(即结果tf.all_variables()),但您可以通过将var_list可选参数传递给tf.train.Saver构造函数来有选择地保存变量:

weights = {
    'wc1_0': tf.Variable(tf.random_normal([5, 5, 3, 64])),
    'wc1_1': tf.Variable(tf.random_normal([5, 5, 3, 64]))
}
biases = {
    'bc1_0': tf.Variable(tf.constant(0.0, shape=[64])),
    'bc1_1': tf.Variable(tf.constant(0.0, shape=[64]))
}

# Define savers for explicit subsets of the variables.
weights_saver = tf.train.Saver(var_list=weights)
biases_saver = tf.train.Saver(var_list=biases)

# ...
# You need a TensorFlow Session to save variables.
sess = tf.Session()
# ...

# ...then call the following methods as appropriate:
weights_saver.save(sess)  # Save the current value of the weights.
biases_saver.save(sess)   # Save the current value of the biases.
Run Code Online (Sandbox Code Playgroud)

请注意,如果将字典传递给tf.train.Saver构造函数(例如问题中的weights和/或biases字典),TensorFlow将使用字典键(例如'wc1_0')作为其创建或使用的任何检查点文件中相应变量的名称.

默认情况下,或者如果将tf.Variable对象列表传递给构造函数,TensorFlow将使用该tf.Variable.name属性.

通过字典,您可以在模型之间共享检查点,Variable.name为每个变量提供不同的属性.仅当您要将创建的检查点与其他模型一起使用时,此详细信息才很重要.