在多个线程中重用 Tensorflow 会话会导致崩溃

And*_*nak 6 python multithreading python-multithreading python-3.x tensorflow

背景:

我有一些复杂的强化学习算法,我想在多个线程中运行。

问题

尝试调用sess.run线程时,我收到以下错误消息:

RuntimeError: The Session graph is empty. Add operations to the graph before calling run().

重现错误的代码:

import tensorflow as tf

import threading

def thread_function(sess, i):
    inn = [1.3, 4.5]
    A = tf.placeholder(dtype=float, shape=(None), name="input")
    P = tf.Print(A, [A])
    Q = tf.add(A, P)
    sess.run(Q, feed_dict={A: inn})

def main(sess):

    thread_list = []
    for i in range(0, 4):
        t = threading.Thread(target=thread_function, args=(sess, i))
        thread_list.append(t)
        t.start()

    for t in thread_list:
        t.join()

if __name__ == '__main__':

    sess = tf.Session()
    main(sess)
Run Code Online (Sandbox Code Playgroud)

如果我在线程外运行相同的代码,它可以正常工作。

有人可以就如何使用 python 线程正确使用 Tensorflow 会话提供一些见解吗?

de1*_*de1 8

Session 不仅可以是当前线程的默认值,还可以是图。当您传入会话​​并调用run它时,默认图形将是不同的。

您可以像这样修改您的thread_function以使其工作:

def thread_function(sess, i):
    with sess.graph.as_default():
        inn = [1.3, 4.5]
        A = tf.placeholder(dtype=float, shape=(None), name="input")
        P = tf.Print(A, [A])
        Q = tf.add(A, P)
        sess.run(Q, feed_dict={A: inn})
Run Code Online (Sandbox Code Playgroud)

但是,我不希望有任何显着的加速。Python 线程在其他一些语言中并不是它的意思,只有某些操作(例如 io)会并行运行。对于 CPU 繁重的操作,它不是很有用。多处理可以真正并行运行代码,但您不会共享同一个会话。