有没有办法找到给定操作(通常是损失)所依赖的所有变量?我想使用它然后将此集合传递给optimizer.minimize()或tf.gradients()使用各种set().intersection()组合。
到目前为止,我已经找到op.op.inputs并尝试了一个简单的 BFS,但我从来没有遇到Variable过tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)或返回的对象slim.get_variables()
相应的“Tensor.op._id andVariables.op._id”字段之间似乎确实存在对应关系,但我不确定这是我应该依赖的东西。
或者,也许我一开始就不应该这样做?我当然可以在构建图形时精心构建不相交的变量集,但是如果我更改模型,很容易遗漏一些东西。
的文档tf.Variable.op不是特别清楚,但它确实提到了tf.Operation在实现 a 中tf.Variable使用的关键:任何依赖于 a 的操作tf.Variable都将在该操作的路径上。由于该tf.Operation对象是可散列的,因此您可以将其用作dict将tf.Operation对象映射到相应tf.Variable对象的 a 的键,然后像以前一样执行 BFS:
op_to_var = {var.op: var for var in tf.trainable_variables()}
starting_op = ...
dependent_vars = []
queue = collections.deque()
queue.append(starting_op)
visited = set([starting_op])
while queue:
op = queue.popleft()
try:
dependent_vars.append(op_to_var[op])
except KeyError:
# `op` is not a variable, so search its inputs (if any).
for op_input in op.inputs:
if op_input.op not in visited:
queue.append(op_input.op)
visited.add(op_input.op)
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
380 次 |
| 最近记录: |