Tensorflow张量值图

zha*_*ulu 5 python machine-learning tensorflow

我的问题是如何用字典映射张量?例如像这样:

dict = {1:3, 2:4}
origin_tensor = tf.Variable([1,2,1], tf.int32)
Run Code Online (Sandbox Code Playgroud)

字典很大。现在,如何根据 dict 制作地图选项以将张量映射到 tf.Variable([3,4,3], tf.int32) ?

更重要的是,映射时无法使用 .eval() ,您可以认为 origin_tensor 是来自批读取器的标签张量。

小智 1

您可以使用tf.map_fn() 函数。由于您描述的情况显示了 x=x+1 的连接,因此在 Tensorflow 中可以将其解释为:

elems = np.array([1, 2])
plus_one = tf.map_fn(lambda x: x + 1, elems)
# plus_one == [3, 4]
Run Code Online (Sandbox Code Playgroud)

  • 更一般的映射怎么样?例如,`m={0:0, 1:2, 2:2, 3:2, 4:2, 5:3, 6:1}`。我们可以使用 dict 来实现类似 tf.map_fn(lambda x: m[x], elems) 的功能吗? (8认同)