查找 tensorflow op 依赖的所有变量

bla*_*dog 3 tensorflow

有没有办法找到给定操作(通常是损失)所依赖的所有变量?我想使用它然后将此集合传递给optimizer.minimize()tf.gradients()使用各种set().intersection()组合。

到目前为止,我已经找到op.op.inputs并尝试了一个简单的 BFS,但我从来没有遇到Variabletf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)或返回的对象slim.get_variables()

相应的“Tensor.op._id andVariables.op._id”字段之间似乎确实存在对应关系,但我不确定这是我应该依赖的东西。

或者,也许我一开始就不应该这样做?我当然可以在构建图形时精心构建不相交的变量集,但是如果我更改模型,很容易遗漏一些东西。

mrr*_*rry 5

文档tf.Variable.op不是特别清楚,但它确实提到了tf.Operation实现 a 中tf.Variable使用的关键:任何依赖于 a 的操作tf.Variable都将在该操作的路径上。由于该tf.Operation对象是可散列的,因此您可以将其用作dicttf.Operation对象映射到相应tf.Variable对象的 a 的键,然后像以前一样执行 BFS:

op_to_var = {var.op: var for var in tf.trainable_variables()}

starting_op = ...
dependent_vars = []

queue = collections.deque()
queue.append(starting_op)

visited = set([starting_op])

while queue:
  op = queue.popleft()
  try:
    dependent_vars.append(op_to_var[op])
  except KeyError:
    # `op` is not a variable, so search its inputs (if any). 
    for op_input in op.inputs:
      if op_input.op not in visited:
        queue.append(op_input.op)
        visited.add(op_input.op)
Run Code Online (Sandbox Code Playgroud)