Tensorflow python:访问张量中的各个元素

cip*_*r42 42 python python-2.7 tensorflow

这个问题是关于访问张量中的单个元素,比如[[1,2,3]].我需要访问内部元素[1,2,3](这可以使用.eval()或sess.run()执行)但是当张量的大小很大时需要更长时间)

有没有什么方法可以更快地做到这一点?

提前致谢.

mrr*_*rry 55

有两种主要方法可以访问张量中元素的子集,其中任何一个都应该适用于您的示例.

  1. 使用索引运算符(基于tf.slice())从张量中提取连续切片.

    input = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    
    output = input[0, :]
    print sess.run(output)  # ==> [1 2 3]
    
    Run Code Online (Sandbox Code Playgroud)

    索引运算符支持许多与NumPy相同的切片规范.

  2. 使用tf.gather()op从张量中选择一个非连续切片.

    input = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    
    output = tf.gather(input, 0)
    print sess.run(output)  # ==> [1 2 3]
    
    output = tf.gather(input, [0, 2])
    print sess.run(output)  # ==> [[1 2 3] [7 8 9]]
    
    Run Code Online (Sandbox Code Playgroud)

    请注意,tf.gather()只允许您选择第0维中的整个切片(矩阵示例中的整行),因此您可能需要tf.reshape()tf.transpose()输入以获取适当的元素.

  • "..所以你可能需要输入tf.reshape()或tf.transpose()以获得适当的元素." - >或者使用`tf.gather_nd`? (6认同)
  • “塞斯”从何而来?它存在于tensorflow 2.xx中吗? (4认同)

Pey*_*man 7

我希望我正确理解你的问题。您可以通过 访问 TensorFlow 2 中张量中的元素.numpy()

import tensorflow as tf
t = tf.constant([[1,2,3]])

print(t.numpy()[0][1]) # This will print 2
Run Code Online (Sandbox Code Playgroud)
>>> 2
Run Code Online (Sandbox Code Playgroud)


Sor*_*rin 1

我怀疑是其余的计算需要时间,而不是访问一个元素。

此外,结果可能需要从存储的任何内存中进行复制,因此,如果它位于显卡上,则需要先将其复制回 RAM,然后才能访问您的元素。如果是这种情况,您可以通过添加张量流操作来获取第一个元素并仅返回该元素来跳过它。