use*_*184 4 python multithreading keras tensorflow
我正在尝试在Keras和TensorFlow中实现actor-critic的异步版本.我正在使用Keras作为构建我的网络层的前端(我正在使用tensorflow直接更新参数).我有一个global_model和一个主张量流会话.但是在每个线程中我创建了一个local_model从中复制参数global_model.我的代码看起来像这样
def main(args):
config=tf.ConfigProto(log_device_placement=False,allow_soft_placement=True)
sess = tf.Session(config=config)
K.set_session(sess) # K is keras backend
global_model = ConvNetA3C(84,84,4,num_actions=3)
threads = [threading.Thread(target=a3c_thread, args=(i, sess, global_model)) for i in range(NUM_THREADS)]
for t in threads:
t.start()
def a3c_thread(i, sess, global_model):
K.set_session(sess) # registering a session for each thread (don't know if it matters)
local_model = ConvNetA3C(84,84,4,num_actions=3)
sync = local_model.get_from(global_model) # I get the error here
#in the get_from function I do tf.assign(dest.params[i], src.params[i])
Run Code Online (Sandbox Code Playgroud)
我收到了来自Keras的用户警告
UserWarning:默认的TensorFlow图形不是与当前在Keras中注册的TensorFlow会话相关联的图形,因此Keras无法自动初始化变量.您应该考虑通过Keras注册正确的会话
K.set_session(sess)
然后是tf.assign操作上的张量流错误,表示操作必须在同一图表上.
ValueError:Tensor("conv1_W:0",shape =(8,8,4,16),dtype = float32_ref,device =/device:CPU:0)必须与Tensor("conv1_W:0")在同一图表中, shape =(8,8,4,16),dtype = float32_ref)
我不确定出了什么问题.
谢谢
错误来自Keras,因为它tf.get_default_graph() is sess.graph正在返回False.从TF文档中,我看到它tf.get_default_graph()返回当前线程的默认图形.当我开始一个新线程并创建一个图形时,它被构建为一个特定于该线程的单独图形.我可以通过以下方式解决这个问题,
with sess.graph.as_default():
local_model = ConvNetA3C(84,84,4,3)
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
2064 次 |
| 最近记录: |