Ale*_*exR 6 machine-learning deep-learning tensorflow recurrent-neural-network
我正在尝试制作一个Tensorflow图,其中图的一部分已经预先训练并在预测模式下运行,而其余的训练.我已经定义了我预先训练好的细胞:
rnn_cell = tf.contrib.rnn.BasicLSTMCell(100)
state0 = tf.Variable(pretrained_state0,trainable=False)
state1 = tf.Variable(pretrained_state1,trainable=False)
pretrained_state = [state0, state1]
outputs, states = tf.contrib.rnn.static_rnn(rnn_cell,
data_input,
dtype=tf.float32,
initial_state = pretrained_state)
Run Code Online (Sandbox Code Playgroud)
设置初始变量trainable=False没有帮助.这些仅用于初始化权重,因此权重仍然会发生变化.
我仍然需要在训练步骤中运行优化器,因为我的模型的其余部分需要训练.但是,如何防止优化器更改此rnn单元格中的权重?
是否有rnn_cell相当于trainable=False?
您可以使用 来tf.stop_gradient()防止pretrained图的各个部分更新其权重,也可以使用 来optimiser()指定应训练图的哪些部分。第二种方法涉及:
#Create variable scope for the trainable parts of the graph: tf.variable_scope('train').
# get trainable variables
t_vars = tf.trainable_variables()
train_vars = [var for var in t_vars if var.name.startswith('train')]
# train only the variables of a particular scope
opt = optimizer.minimize(cost, var_list=train_vars)
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
921 次 |
| 最近记录: |