How to list all TensorFlow variables a node depends on?

Fra*_*urt 13 python tensorflow

How can I list all the TensorFlow variables/constants/placeholders that a node depends on?

Example 1 (adding constants):

import tensorflow as tf

a = tf.constant(1, name = 'a')
b = tf.constant(3, name = 'b')
c = tf.constant(9, name = 'c')
d = tf.add(a, b, name='d')
e = tf.add(d, c, name='e')

sess = tf.Session()
print(sess.run([d, e]))

I would like a function list_dependencies() such that:

  • list_dependencies(d) returns ['a', 'b']
  • list_dependencies(e) returns ['a', 'b', 'c']

Example 2 (matrix multiplication between a placeholder and a weight matrix, then adding a bias vector):

tf.set_random_seed(1)
input_size  = 5
output_size = 3
input       = tf.placeholder(tf.float32, shape=[1, input_size], name='input')
W           = tf.get_variable(
                "W",
                shape=[input_size, output_size],
                initializer=tf.contrib.layers.xavier_initializer())
b           = tf.get_variable(
                "b",
                shape=[output_size],
                initializer=tf.constant_initializer(2))
output      = tf.matmul(input, W, name="output")
output_bias = tf.nn.xw_plus_b(input, W, b, name="output_bias")

sess = tf.Session()
sess.run(tf.global_variables_initializer())
print(sess.run([output,output_bias], feed_dict={input: [[2]*input_size]}))

I would like a function list_dependencies() such that:

  • list_dependencies(output) returns ['W', 'input']
  • list_dependencies(output_bias) returns ['W', 'b', 'input']

Yar*_*tov 15

Here are the utilities I use for this (from https://github.com/yaroslavvb/stuff/blob/master/linearize/linearize.py):

import tensorflow as tf

# computation flows from parents to children

def parents(op):
  """Returns the set of ops whose outputs feed into op."""
  return set(tensor.op for tensor in op.inputs)

def children(op):
  """Returns the set of ops that consume any output of op."""
  return set(consumer for out in op.outputs for consumer in out.consumers())

def get_graph():
  """Creates dictionary {node: {child1, child2, ..},..} for current
  TensorFlow graph. Result is compatible with networkx/toposort"""

  ops = tf.get_default_graph().get_operations()
  return {op: children(op) for op in ops}


def print_tf_graph(graph):
  """Prints tensorflow graph in dictionary form."""
  for node in graph:
    for child in graph[node]:
      print("%s -> %s" % (node.name, child.name))
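Because get_graph() returns a {op: {child ops}} dictionary, one possible way to answer the question (a sketch, not part of the original utility) is to invert that dictionary into parent edges and walk upward from the target node, keeping the ancestors that have no parents of their own. The helper names invert and source_ancestors are hypothetical. To keep the snippet runnable without TensorFlow, the demo uses plain strings for the nodes of Example 1 instead of real tf.Operation objects.

```python
from collections import deque


def invert(children_graph):
    """Turn a {node: {child, ...}} dict (the shape get_graph() returns)
    into a {node: {parent, ...}} dict."""
    parents_graph = {node: set() for node in children_graph}
    for node, kids in children_graph.items():
        for kid in kids:
            parents_graph.setdefault(kid, set()).add(node)
    return parents_graph


def source_ancestors(children_graph, target):
    """Return the set of ancestors of `target` that have no parents
    (the 'leaves' of the dependency walk). Breadth-first over parent edges."""
    parents_graph = invert(children_graph)
    seen = {target}
    sources = set()
    queue = deque([target])
    while queue:
        node = queue.popleft()
        preds = parents_graph.get(node, set())
        if not preds:
            sources.add(node)          # nothing feeds this node
        for p in preds:
            if p not in seen:
                seen.add(p)
                queue.append(p)
    return sources


# Hypothetical string model of Example 1: a, b -> d; d, c -> e
example1 = {"a": {"d"}, "b": {"d"}, "c": {"e"}, "d": {"e"}, "e": set()}
print(sorted(source_ancestors(example1, "d")))  # ['a', 'b']
print(sorted(source_ancestors(example1, "e")))  # ['a', 'b', 'c']
```

With a real TF1 graph the call would presumably look like sorted(op.name for op in source_ancestors(get_graph(), e.op)), sorting by op.name rather than on the ops themselves. Since children() follows Tensor.consumers(), only data edges are considered; control dependencies are not.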

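The functions above (parents, children, get_graph) work on operations. To get the op that produced a tensor t, use t.op; to get the tensors produced by an op, use op.outputs.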

  • It might be a good idea to provide this in [graph_util](https://cs.corp.google.com/piper///depot/google3/third_party/tensorflow/python/framework/graph_util_impl.py?q=file: third_party/tensorflow.*graph_util&sq = package:piper + file:// depot/google3 + -file:google3/experimental&dr&l = 110), or through contrib. (2 upvotes)