我可以将tf.map_fn(...)应用于多个输入/输出吗?

Dav*_*rks 8 tensorflow

a = tf.constant([[1,2,3],[4,5,6]])
b = tf.constant([True, False], dtype=tf.bool)

a.eval()
array([[1, 2, 3],
       [4, 5, 6]], dtype=int32)
b.eval()
array([ True, False], dtype=bool)
Run Code Online (Sandbox Code Playgroud)

我想将函数应用于上面的输入a,并b使用tf.map_fn.它将输入两个[1,2,3],True并输出相似的值.

假设函数只是标识:lambda(x,y): x,y因此,给定输入[1,2,3], True,它将输出相同的张量.

我知道如何使用tf.map_fn(...)一个变量,但不能使用两个变量.在这种情况下,我有混合数据类型(int32和bool)所以我不能简单地连接张量并在调用后拆分它们.

我可以使用tf.map_fn(...)不同数据类型的多个输入/输出吗?

Dav*_*rks 13

弄清楚了.你必须dtype为每个不同的张量定义每个张量的数据类型,然后你可以将张量作为元组传递,你的map函数接收一个输入元组,然后map_fn返回一个元组.

有效的示例:

a = tf.constant([[1,2,3],[4,5,6]])
b = tf.constant([True, False], dtype=tf.bool)

c = tf.map_fn(lambda x: (x[0], x[1]), (a,b), dtype=(tf.int32, tf.bool))

c[0].eval()
array([[1, 2, 3],
       [4, 5, 6]], dtype=int32)
c[1].eval()
array([ True, False], dtype=bool)
Run Code Online (Sandbox Code Playgroud)

  • 请注意,如果使用此功能,则处理将在CPU中完成,而不是在GPU上完成.在GPU上进行训练时,这对速度尤其有害. (5认同)
  • 你怎么知道这个特定的函数使用的是 CPU 而不是 GPU?是由于函数 tf.map_fn 本身还是 dtype 规范? (2认同)
  • 几年后重新审视这个问题,我不太确定那个评论。如果 lambda 只返回 TF 操作,我不确定它们是否会下降到 CPU,我不确定从那时到现在 TF 中是否有任何变化。如果有人想通过跟踪或转储张量连接到哪些设备来确认这将是有用的。似乎如果 `fn` 只返回 TF ops,那些应该在默认设备上运行。 (2认同)