过滤出张量中的非零值

use*_*674 7 python numpy python-2.7 python-3.x tensorflow

假设我有一个数组:input = np.array([[1,0,3,5,0,8,6]]),我想过滤掉[1,3,5,8,6].

我知道您可以使用tf.where条件,但返回的值仍然为0.以下代码段的输出是 [[[1 0 3 5 0 8 6]]].我也不明白为什么tf.where需要xy.

无论如何,我可以摆脱由此产生的张量中的0?

import numpy as np
import tensorflow as tf

input = np.array([[1,0,3,5,0,8,6]])

X = tf.placeholder(tf.int32,[None,7])

zeros = tf.zeros_like(X)
index = tf.not_equal(X,zeros)
loc = tf.where(index,x=X,y=X)

with tf.Session() as sess:
    out = sess.run([loc],feed_dict={X:input})
    print np.array(out)
Run Code Online (Sandbox Code Playgroud)

The*_*erd 6

首先创建一个布尔掩码来识别你的条件在哪里;然后将掩码应用于您的张量,如下所示。如果你想使用 tf.where 来索引,你可以 - 但是它使用 x&y 返回一个张量与输入具有相同的等级,所以没有进一步的工作,你可以获得的最好的结果将是 [[[1 -1 3 5 -1 8 6]]] 将 -1 更改为您稍后会确定要删除的内容。仅使用 where(不带 x&y)将为您提供条件为真的所有值的索引,因此如果您愿意,可以使用索引创建解决方案。为了最清楚,我的建议如下。

import numpy as np
import tensorflow as tf
input = np.array([[1,0,3,5,0,8,6]])
X = tf.placeholder(tf.int32,[None,7])
zeros = tf.cast(tf.zeros_like(X),dtype=tf.bool)
ones = tf.cast(tf.ones_like(X),dtype=tf.bool)
loc = tf.where(input!=0,ones,zeros)
result=tf.boolean_mask(input,loc)
with tf.Session() as sess:
 out = sess.run([result],feed_dict={X:input})
 print (np.array(out))
Run Code Online (Sandbox Code Playgroud)