什么是张量流漂浮参考?

Ano*_*112 5 python tensorflow

尝试运行以下基本示例来运行条件计算我收到以下错误消息:

'x'传递float与预期的float_ref不兼容

什么是tensorflow float_ref以及如何修改代码?

import tensorflow as tf
from tensorflow.python.ops.control_flow_ops import cond

a = tf.Variable(tf.constant(0.),name="a")
b = tf.Variable(tf.constant(0.),name="b")
x = tf.Variable(tf.constant(0.),name="x")

def add():
    x.assign( a + b)
    return x

def last():
    return x

calculate= cond(x==0.,add,last)

with tf.Session() as s:
    val = s.run([calculate], {a: 1., b: 2., x: 0.})
    print(val) # 3
    val=s.run([calculate],{a:4.,b:5.,x:val})
    print(val) # 3
Run Code Online (Sandbox Code Playgroud)

reu*_*ohn 2

float_ref这里指的是对浮点数的引用,即你的 Tensorflow float 变量x

正如此处所解释的,您面临此错误,因为您无法在同一会话运行中同时分配和传递变量作为 feed_dict,就像您在以下语句中所做的那样:

val = s.run([calculate], {a: 1., b: 2., x: 0.})
Run Code Online (Sandbox Code Playgroud)

当您解决该语句并最终得到以下结果时,情况会变得更加明显:

val = s.run([x.assign( a + b)], {a: 1., b: 2., x: 0.})
Run Code Online (Sandbox Code Playgroud)