如何基于张量流中的一些谓词从队列中过滤张量?

jra*_*ary 4 tensorflow

如何使用谓词函数过滤存储在队列中的数据?例如,假设我们有一个存储功能和标签张量的队列,我们​​只需要满足谓词的那些.我尝试了以下实现但没有成功:

feature, label = queue.dequeue()
if (predicate(feature, label)):
    enqueue_op = another_queue.enqueue(feature, label)
Run Code Online (Sandbox Code Playgroud)

dga*_*dga 8

最直接的方法是将批处理出列,通过谓词测试运行它们,用于tf.where生成与谓词匹配的密集向量,并用于tf.gather收集结果,并将该批次排入队列.如果您希望自动执行此操作,则可以在第二个队列上启动队列运行程序 - 最简单的方法是使用tf.train.batch:

例:

import numpy as np
import tensorflow as tf

a = tf.constant(np.array([5, 1, 9, 4, 7, 0], dtype=np.int32))

q = tf.FIFOQueue(6, dtypes=[tf.int32], shapes=[])
enqueue = q.enqueue_many([a])
dequeue = q.dequeue_many(6)
predmatch = tf.less(dequeue, [5])
selected_items = tf.reshape(tf.where(predmatch), [-1])
found = tf.gather(dequeue, selected_items)

secondqueue = tf.FIFOQueue(6, dtypes=[tf.int32], shapes=[])
enqueue2 = secondqueue.enqueue_many([found])
dequeue2 = secondqueue.dequeue_many(3) # XXX, hardcoded

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(enqueue)  # Fill the first queue
  sess.run(enqueue2) # Filter, push into queue 2
  print sess.run(dequeue2) # Pop items off of queue2
Run Code Online (Sandbox Code Playgroud)

谓词产生一个布尔向量; 在tf.where产生真正价值的指标,以及密集的载体tf.gather从原来的张量根据这些指标收集的物品.

在这个例子中,很多东西都是硬编码的,当然,你需要在现实中进行非硬编码,但希望它能显示你想要做的事情的结构(创建一个过滤管道).在实践中,你需要在那里使用QueueRunners来自动调整内容.使用tf.train.batch对于自动处理它非常有用 - 有关更多详细信息,请参阅线程和队列.