Tensorflow:如何创建 const 节点并添加到图形中?

mrg*_*oom 3 python tensorflow

我有存储在 中的现有模型.pb,我想通过 op 将一些常量张量(最初我在 numpy 数组中有数据)添加到图中的某个张量Add,如何完成?

据我了解,图中的每个节点都是,所以我需要为optensorflow.core.framework.node_def_pb2.NodeDef创建一个节点,为 const 张量创建一个节点?Add

这些是相关问题:

TensorFlow手动构建GraphDef

如何从 Tensorflow 中的 .pb 模型获取权重

jde*_*esa 6

编辑:下面的答案只是向图表添加一个断开连接的常量,如果您想添加一个新的添加操作,那么它将是这样的:

import tensorflow as tf

constant_value = ...
with tf.Graph().as_default():
    gd = tf.GraphDef()
    with open('my_graph.pb', 'rb') as f:
        gd.MergeFromString(f.read())
    my_tensor = tf.import_graph_def(gd, name='', return_elements='SomeOperation:0')
    tf.add(my_tensor, constant_value, name='NewOperation')
    tf.train.write_graph(tf.get_default_graph(), '.',
                         'my_modified_graph.pb', as_text=False)
Run Code Online (Sandbox Code Playgroud)

但请注意,这只是添加新操作,它不会修改原始张量的值。我不确定其中哪一个是您想要的。


最实用的方法是导入图形,添加常量并再次保存:

import tensorflow as tf

new_constant = ...
with tf.Graph().as_default():
    gd = tf.GraphDef()
    with open('my_graph.pb', 'rb') as f:
        gd.MergeFromString(f.read())
    tf.import_graph_def(gd, name='')
    tf.constant(new_constant, name='NewConstant')
    tf.train.write_graph(tf.get_default_graph(), '.',
                         'my_graph_with_constant.pb', as_text=False)
Run Code Online (Sandbox Code Playgroud)

如果由于某种原因您不想导入图表,您可以手动构建节点对象,如下所示:

import numpy as np
import tensorflow as tf
from tensorflow.core.framework.tensor_pb2 import TensorProto
from tensorflow.core.framework.tensor_shape_pb2 import TensorShapeProto

# New constant to add
my_value = np.array([[1, 2, 3], [4, 5,6]], dtype=np.int32)
# Make graph node
tensor_content = my_value.tobytes()
dt = tf.as_dtype(my_value.dtype).as_datatype_enum
tensor_shape = TensorShapeProto(dim=[TensorShapeProto.Dim(size=s) for s in my_value.shape])
tensor_proto = TensorProto(tensor_content=tensor_content,
                           tensor_shape=tensor_shape,
                           dtype=dt)
node = tf.NodeDef(name='MyConstant', op='Const',
                  attr={'value': tf.AttrValue(tensor=tensor_proto),
                        'dtype': tf.AttrValue(type=dt)})
# Read existing graph
gd = tf.GraphDef()
with open('my_graph.pb', 'rb') as f:
    gd.MergeFromString(f.read())
# Add new node
gd.node.extend([node])
# Save modified graph
tf.train.write_graph(tf.get_default_graph(), '.',
                     'my_graph_with_constant.pb', as_text=False)
Run Code Online (Sandbox Code Playgroud)

请注意,这种情况相对容易,因为该节点未连接到任何其他节点。