TensorFlow: create tf.NodeDef() and set attributes

mrg*_*oom 8 python protocol-buffers tensorflow

I'm trying to create a new node and set its attributes.

For example, when I print one of the graph nodes, I see that its attributes are:

attr {
  key: "T"
  value {
    type: DT_FLOAT
  }
}
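Here is a minimal sketch of where a printout like the one above comes from. It assumes TensorFlow 2.x, using an explicit tf.Graph so that ops are built in graph mode (1.x behaves the same). Every tf.Operation exposes its NodeDef through node_def, and attr maps attribute names to AttrValue messages. The op names a, b and MySub are illustrative.

```python
import tensorflow as tf
from tensorflow.core.framework import types_pb2

# Build a tiny graph in graph mode so that every op has a NodeDef.
g = tf.Graph()
with g.as_default():
    a = tf.constant([5.0, 7.0], name="a")
    b = tf.constant([1.0, 2.0], name="b")
    tf.subtract(a, b, name="MySub")  # creates an op of type "Sub"

# Every tf.Operation exposes its underlying NodeDef protobuf.
sub_def = g.get_operation_by_name("MySub").node_def
t_attr = sub_def.attr["T"]  # the key is the attribute *name*

print(t_attr)                             # the `type: DT_FLOAT` entry
print(t_attr.WhichOneof("value"))         # which AttrValue field is set: 'type'
print(t_attr.type == types_pb2.DT_FLOAT)  # True
```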

I can create a node like this:

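import tensorflow as tf  # TF 1.x, where tf.NodeDef and tf.AttrValue are top-level aliases

# tensor_proto (a TensorProto) and dt (a DataType enum value) are defined earlier
node = tf.NodeDef(name='MyConstTensor', op='Const',
                  attr={'value': tf.AttrValue(tensor=tensor_proto),
                        'dtype': tf.AttrValue(type=dt)})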
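The snippet above does not show how tensor_proto and dt are built. Below is a minimal runnable sketch that assumes a two-element float32 tensor. It uses the generated protobuf modules under tensorflow.core.framework rather than the TF 1.x aliases tf.NodeDef/tf.AttrValue, so it should also run on TF 2.x. Note that dt is an integer enum value, not a string.

```python
import tensorflow as tf
from tensorflow.core.framework import attr_value_pb2, node_def_pb2

# Assumed inputs: the question does not show how these are created.
tensor_proto = tf.make_tensor_proto([1.0, 2.0], dtype=tf.float32)
dt = tf.float32.as_datatype_enum  # DataType enum value (an int), here 1

node = node_def_pb2.NodeDef(
    name="MyConstTensor",
    op="Const",
    attr={
        "value": attr_value_pb2.AttrValue(tensor=tensor_proto),
        "dtype": attr_value_pb2.AttrValue(type=dt),
    },
)
print(node)
```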

But how do I add the key: "T" attribute? That is, what should go inside tf.AttrValue in this case?

Looking at attr_value.proto, I have tried:

node = tf.NodeDef()
node.name = 'MySub'
node.op = 'Sub'
node.input.extend(['MyConstTensor', 'conv2'])
node.attr["key"].s = 'T' # TypeError: 'T' has type str, but expected one of: bytes
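Two things are going on in the attempt above. First, the subscript of node.attr is the attribute name, so attr["key"] creates an attribute literally called "key". Second, the s field of AttrValue is declared as bytes in attr_value.proto, hence the TypeError for a Python str. The small sketch below illustrates both, using tensorflow.core.framework.node_def_pb2, the generated module behind tf.NodeDef. Storing b"T" here only demonstrates the field's type; it is not what the Sub op needs.

```python
from tensorflow.core.framework import node_def_pb2

node = node_def_pb2.NodeDef(name="MySub", op="Sub",
                            input=["MyConstTensor", "conv2"])

# attr is a map<string, AttrValue>: this creates an attribute NAMED "key".
# The `s` field is declared as bytes, so bytes are accepted where a str was not.
node.attr["key"].s = b"T"

print(list(node.attr))                    # ['key']  (attribute names)
print(node.attr["key"].WhichOneof("value"))  # 's'  (which field is set)
```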

Update:

I found that in TensorFlow it should be written like this:

node.attr["T"].type = b'float32'

But this gives an error:

TypeError: b'float32' has type bytes, but expected one of: int, long

And I'm not sure which int value corresponds to float32.
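The type field holds a value of the DataType enum defined in tensorflow/core/framework/types.proto, so it takes an integer rather than a string. This sketch shows a few ways to look that integer up instead of guessing, assuming a standard TensorFlow install.

```python
import tensorflow as tf
from tensorflow.core.framework import types_pb2

# Equivalent ways to get the DataType enum number for float32.
print(types_pb2.DT_FLOAT)                       # 1
print(types_pb2.DataType.Value("DT_FLOAT"))     # 1
print(tf.float32.as_datatype_enum)              # 1
print(tf.as_dtype("float32").as_datatype_enum)  # 1

# And the reverse: enum number back to its name.
print(types_pb2.DataType.Name(1))               # DT_FLOAT
```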

https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/attr_value.proto#L23

https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/attr_value.proto#L35

mrg*_*oom 3

Through trial and error, I found that it is simply:

node.attr["T"].type = 1 # to set type to float32
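Putting it together, here is a sketch that sets T with the named constant types_pb2.DT_FLOAT (equal to the 1 used above). It then checks the result by importing the hand-built nodes into a graph and running it. Assumptions: conv2 is replaced by a second float32 Const so that the graph is self-contained, const_node is a hypothetical helper, and TensorFlow with tf.compat.v1 is available.

```python
import tensorflow as tf
from tensorflow.core.framework import graph_pb2, node_def_pb2, types_pb2


def const_node(name, values):
    """Hypothetical helper: build a float32 Const NodeDef holding `values`."""
    node = node_def_pb2.NodeDef(name=name, op="Const")
    node.attr["value"].tensor.CopyFrom(
        tf.make_tensor_proto(values, dtype=tf.float32))
    node.attr["dtype"].type = types_pb2.DT_FLOAT
    return node


sub = node_def_pb2.NodeDef(name="MySub", op="Sub")
sub.input.extend(["MyConstTensor", "conv2"])
sub.attr["T"].type = types_pb2.DT_FLOAT  # same as = 1, but self-documenting

graph_def = graph_pb2.GraphDef()
graph_def.node.extend([
    const_node("MyConstTensor", [5.0, 7.0]),
    const_node("conv2", [1.0, 2.0]),  # stand-in for the real conv2 output
    sub,
])

# Importing checks each NodeDef against its registered op definition.
with tf.Graph().as_default() as g:
    tf.compat.v1.import_graph_def(graph_def, name="")
    with tf.compat.v1.Session(graph=g) as sess:
        result = sess.run("MySub:0")

print(result)  # [4. 5.]
```

Because the import validates the nodes, a missing or mistyped T attribute should be reported when the GraphDef is imported rather than later.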