Tensorflow:根据偶数和奇数索引合并两个二维张量

Slo*_*oke 5 python tensorflow

我想检查批次中的偶数和奇数元素,并在需要时交换它们。我设法得到两个我想交织的张量:

def tf_oplu(x, name=None):   


    even = x[:,::2] #slicing into odd and even parts on the batch
    odd = x[:,1::2]

    even_flatten = tf.reshape(even, [-1]) # flatten tensors 
    #in row-major order to apply function across them

    odd_flatten = tf.reshape(odd, [-1])

    compare = tf.to_float(even_flatten<odd_flatten)
    compare_not = tf.to_float(even_flatten>=odd_flatten)

    #def oplu(x,y): # trivial function  
    #    if x<y : # (x<y)==1
    #       return y, x 
    #    else:
    #       return x, y # (x<y)==0

    even_flatten_new = odd_flatten * compare + even_flatten * compare_not
    odd_flatten_new = odd_flatten * compare_not + even_flatten * compare

    # convolute back 

    even_new = tf.reshape(even_flatten_new,[100,128])
    odd_new = tf.reshape(odd_flatten_new,[100,128])
Run Code Online (Sandbox Code Playgroud)

现在我想找回[[100,256]]张量,并填充偶数和奇数位。在numpy中,我当然会做:

y = np.empty((even_new.size + odd_newsize,), dtype=even_new.dtype)
y[:,0::2] = even_new
y[:,1::2] = odd_new

return y   
Run Code Online (Sandbox Code Playgroud)

但是这种事情对于Tensoflow是不可能的,因为张量不可修改。我想使用稀疏张量tf.gather_nd都是可能的,但是两者都需要生成索引数组,这对我来说也不是简单的任务。还有一点要注意:我不想通过使用任何python函数tf.py_func,因为我检查了它们仅在CPU上运行。也许lambda tf.map_fn可能会有所帮助?谢谢!

P-G*_*-Gn 5

要垂直交织两个矩阵,请不要使用诸如gather或的大枪map_fn。您可以按如下所示简单地交错它们:

tf.reshape(
  tf.stack([even_new, odd_new], axis=1),
  [-1, tf.shape(even_new)[1]])
Run Code Online (Sandbox Code Playgroud)

编辑

要水平交错它们:

tf.reshape(
  tf.concat([even_new[...,tf.newaxis], odd_new[...,tf.newaxis]], axis=-1), 
  [tf.shape(even_new)[0],-1])
Run Code Online (Sandbox Code Playgroud)

这个想法是使用堆栈将它们插入内存中。发生堆栈的尺寸给出了交织的粒度。如果我们在处堆叠axis=0,则交织发生在每个元素上,混合列。如果我们在处堆叠axis=1,则整个输入行保持连续,行之间会发生交织。