tensorflow 在 python for 循环中运行速度非常慢

0 tensorflow

以下代码使用 tensorflow 库,与 numpy 库相比,它的运行速度要慢得多。我知道我正在调用一个在 python for 循环中使用 tensorflow 库的函数(稍后我将与 python 多处理并行),但代码按原样运行,运行速度非常慢。

有人可以帮助我如何使此代码运行得更快吗?谢谢。


from math import *
import numpy as np
import sys
from multiprocessing import Pool
import tensorflow as tf

def Trajectory_Fun(tspan, a, b, session=None, server=None):
    # Open tensorflow session
    if session==None:
        if server==None:
            sess = tf.Session()
        else:
            sess = tf.Session(server.target)       
    else:
        sess = session
    B = np.zeros(np.size(tspan), dtype=np.float64)
    B[0] = b
    for i, t in enumerate(tspan):
        r = np.random.rand(1)
        if r>a:
            c = sess.run(tf.trace(tf.random_normal((4, 4), r, 1.0))) 
        else:
            c = 0.0 # sess.run(tf.trace(tf.random_normal((4, 4), 0.0, 1.0)))
        B[i] = c
    # Close tensorflow session
    if session==None:
        sess.close()
    return B

def main(argv):
    # Parameters
    tspan = np.arange(0.0, 1000.0)
    a = 0.1
    b = 0.0
    # Run test program
    B = Trajectory_Fun(tspan, a, b, None, None)
    print 'Done!'

if __name__ == "__main__":
    main(sys.argv[1:])
Run Code Online (Sandbox Code Playgroud)

mrr*_*rry 5

如您的问题所述,该程序的性能会很差,因为它会为每个操作创建几个新的 TensorFlow 图节点。TensorFlow 中的基本假设是(大约)您将构建一次图,然后多次调用sess.run()它(的各个部分)。第一次运行图的成本相对较高,因为 TensorFlow 必须构建各种数据结构并优化图在多个设备上的执行。但是,TensorFlow 缓存了这项工作,因此后续使用要便宜得多。

您可以通过构建一次图形并使用(例如)一个tf.placeholder()操作来输入每次迭代中发生变化的值,从而使该程序更快。例如,以下应该可以解决问题:

B = np.zeros(np.size(tspan), dtype=np.float64)
B[0] = b

# Define the TensorFlow graph once and reuse it in each iteration of the for loop.
r_placeholder = tf.placeholder(tf.float32, shape=[])
out_t = tf.trace(tf.random_normal((4, 4), r_placeholder, 1.0))

with tf.Session() as sess:
  for i, t in enumerate(tspan):
    r = np.random.rand(1)
    if r > a:
      c = sess.run(out_t, feed_dict={r_placeholder: r})
    else:
      c = 0.0
    B[i] = c
  return B
Run Code Online (Sandbox Code Playgroud)

您可以通过使用 TensorFlow 循环并减少对 的调用来提高效率sess.run(),但一般原则是相同的:多次重用相同的图以获得 TensorFlow 的好处。