tf.map_fn 是如何工作的?

gau*_*clb 5 tensorflow

看演示:

elems = np.array([1, 2, 3, 4, 5, 6])
squares = map_fn(lambda x: x * x, elems)
# squares == [1, 4, 9, 16, 25, 36]

elems = (np.array([1, 2, 3]), np.array([-1, 1, -1]))
alternate = map_fn(lambda x: x[0] * x[1], elems, dtype=tf.int64)
# alternate == [-1, 2, -3]

elems = np.array([1, 2, 3])
alternates = map_fn(lambda x: (x, -x), elems, dtype=(tf.int64, tf.int64))
# alternates[0] == [1, 2, 3]
# alternates[1] == [-1, -2, -3]
Run Code Online (Sandbox Code Playgroud)

第二个和第三个我看不懂。

对于第二次:我认为结果是[2, -1],因为第一次 x=np.array([1, 2, 3]) 并返回 1*2,第二次 x=np.array([ -1, 1, -1]) 并返回 1*(-1)

对于第三个:我认为结果的形状是(3,2),因为第一次x=1并返回(1,-1),第二次x=2并返回(2,-2),第三次时间x=3并返回(3,-3)。

那么map_fn是如何工作的呢?

小智 1

Tensorflow map_fn,来自文档,

映射到从维度 0 上的 elems 解压出来的张量列表上。

在这种情况下,输入张量的唯一轴 [1,2,3] 或 [-1,1,-1]。因此,运算为 1*-1,2*1 和 3*-1,结果被重新打包,得到张量形状。