等效于 tensorflow 中的 enumerate 以在 tf.map_fn 中使用索引

klu*_*luu 3 python tensorflow

我有一个张量,我使用tf.map_fn. 现在我想将索引作为参数包含在我传递给的函数中tf.map_fn。在 numpy 中,我可以enumerate用来获取该信息并将其传递到我的 lambda 函数中。这是 numpy 中的一个示例,我将 0 添加到第一行,将 1 添加到第二行,依此类推:

a = np.array([[2, 1], [4, 2], [-1, 2]])

def some_function(x, i):
    return x + i

res = map(lambda (i, row): some_function(row, i), enumerate(a))
print(res)

> [array([2, 1]), array([5, 3]), array([1, 4])]
Run Code Online (Sandbox Code Playgroud)

我一直无法enumerate在 tensorflow 中找到等效项,我不知道如何在 tensorflow 中获得相同的结果。有人知道用什么让它在 tensorflow 中工作吗?这是一个示例代码,我在其中的每一行中添加 1 a

import tensorflow as tf

a = tf.constant([[2, 1], [4, 2], [-1, 2]])

with tf.Session() as sess:
    res = tf.map_fn(lambda row: some_function(row, 1), a)
    print(res.eval())

> [[3 2]
   [5 3]
   [0 3]]
Run Code Online (Sandbox Code Playgroud)

感谢任何可以帮助我解决这个问题的人。

ben*_*che 6

tf.map_fn()可以有很多输入/输出。因此,您可以使用tf.range()来构建行索引张量并将其用于:

import tensorflow as tf

def some_function(x, i):
    return x + i

a = tf.constant([[2, 1], [4, 2], [-1, 2]])
a_rows = tf.expand_dims(tf.range(tf.shape(a)[0], dtype=tf.int32), 1)

res, _ = tf.map_fn(lambda x: (some_function(x[0], x[1]), x[1]), 
                   (a, a_rows), dtype=(tf.int32, tf.int32))

with tf.Session() as sess:
    print(res.eval())
    # [[2 1]
    #  [5 3]
    #  [1 4]]
Run Code Online (Sandbox Code Playgroud)

注意:在许多情况下,“逐行处理矩阵”可以立即完成,例如通过广播,而不是使用循环:

import tensorflow as tf

a = tf.constant([[2, 1], [4, 2], [-1, 2]])
a_rows = tf.expand_dims(tf.range(tf.shape(a)[0], dtype=tf.int32), 1)
res = a + a_rows

with tf.Session() as sess:
    print(res.eval())
    # [[2 1]
    #  [5 3]
    #  [1 4]]
Run Code Online (Sandbox Code Playgroud)