如何在TensorFlow中的SparseTensor中选择一行?

sol*_*ice 2 python embedding tensorflow

说,如果我有两个SparseTensor如下:

[[1, 0, 0, 0],
 [2, 0, 0, 0],
 [1, 2, 0, 0]]
Run Code Online (Sandbox Code Playgroud)

[[1.0, 0, 0, 0],
 [1.0, 0, 0, 0],
 [0.3, 0.7, 0, 0]]
Run Code Online (Sandbox Code Playgroud)

我想从其中提取前两行。我需要索引和非零项的值都为SparseTensors,以便可以将结果传递给tf.nn.embedding_lookup_sparse。我怎样才能做到这一点?

我的应用程序是:我想使用单词嵌入,这在TensorFlow中非常简单。但是现在我想使用稀疏嵌入,即:对于普通单词,它们具有自己的嵌入。对于稀有词,它们的嵌入是常见词的嵌入的稀疏线性组合。因此,我需要两本食谱来说明稀疏嵌入的组成方式。在上述示例中,菜谱说:对于第一个单词,其嵌入由权重为1.0的自身嵌入组成。第二个单词的情况相似。对于最后一个单词,它说:该单词的嵌入是前两个单词的嵌入的线性组合,并且相应的权重分别为0.3和0.7。我需要提取一行,然后将索引和权重输入tf.nn.embedding_lookup_sparse获得最终的嵌入。我如何在TensorFlow中做到这一点?

还是我需要解决它,即:预处理数据并使用TensorFlow处理菜谱?

Pet*_*den 5

我在这里与其中一位对此领域有更多了解的工程师签到,这是他所传递的内容:

我不确定我们是否可以有效地实现这一点,但是这是使用dynamic_partition和collect ops的非理想实现。

def sparse_slice(indices, values, needed_row_ids):
   num_rows = tf.shape(indices)[0]
   partitions = tf.cast(tf.equal(indices[:,0], needed_row_ids), tf.int32)
   rows_to_gather = tf.dynamic_partition(tf.range(num_rows), partitions, 2)[1]
   slice_indices = tf.gather(indices, rows_to_gather)
   slice_values = tf.gather(values, rows_to_gather)
   return slice_indices, slice_values

with tf.Session().as_default():
  indices = tf.constant([[0,0], [1, 0], [2, 0], [2, 1]])
  values = tf.constant([1.0, 1.0, 0.3, 0.7], dtype=tf.float32)
  needed_row_ids = tf.constant([1])
  slice_indices, slice_values = sparse_slice(indices, values, needed_row_ids)
  print(slice_indices.eval(), slice_values.eval())
Run Code Online (Sandbox Code Playgroud)

更新:

工程师也发送了一个示例来帮助处理多行,感谢您指出这一点!

def sparse_slice(indices, values, needed_row_ids):
  needed_row_ids = tf.reshape(needed_row_ids, [1, -1])
  num_rows = tf.shape(indices)[0]
  partitions = tf.cast(tf.reduce_any(tf.equal(tf.reshape(indices[:,0], [-1, 1]), needed_row_ids), 1), tf.int32)
  rows_to_gather = tf.dynamic_partition(tf.range(num_rows), partitions, 2)[1]
  slice_indices = tf.gather(indices, rows_to_gather)
  slice_values = tf.gather(values, rows_to_gather)
  return slice_indices, slice_values

with tf.Session().as_default():
  indices = tf.constant([[0,0], [1, 0], [2, 0], [2, 1]])
  values = tf.constant([1.0, 1.0, 0.3, 0.7], dtype=tf.float32)
  needed_row_ids = tf.constant([0, 2])
Run Code Online (Sandbox Code Playgroud)