如何使用 tf.Print 打印张量的一部分?

use*_*379 5 tensorflow

在 TensorFlow 中,当张量具有大维度时,出于调试目的仅打印张量的一部分是有用的,例如二维矩阵的对角线。我只知道如何打印整个张量如下:

sess = tf.InteractiveSession()
a = tf.constant(1.0, shape=[1000, 1000])
a = tf.Print(a, [a], "print entire a\n", summarize=1000000)
b = a + 1.
ret = sess.run(b)
Run Code Online (Sandbox Code Playgroud)

上面的代码将打印整个“a”张量。但我不确定如何打印“a”的一部分。例如,如果我只想打印 a[0,0] 而不执行 sess.run(a),则以下代码将不起作用:

sess = tf.InteractiveSession()
a = tf.constant(1.0, shape=[1000, 1000])
a[0,0] = tf.Print(a[0,0], [a[0,0]], "print part of a\n", summarize=1000000)
b = a + 1.
ret = sess.run(b)    
Run Code Online (Sandbox Code Playgroud)

Jul*_*yes 2

从文档中:

Print(
    input_,
    data,
    message=None,
    first_n=None,
    summarize=None,
    name=None
)
Run Code Online (Sandbox Code Playgroud)

打印张量列表。这是一个身份操作,在评估时会产生打印的副作用data

a应保持不变,并且在data参数中应放置需要打印的张量。

import tensorflow as tf

sess = tf.InteractiveSession()
a = tf.constant(1.0, shape=[1000, 1000])
a = tf.Print(a, [a[0, 0]], "Print part of a\n", summarize=100000)
b = a + 1.
ret = sess.run(b)
Run Code Online (Sandbox Code Playgroud)