Tensorflow收集矩阵的列非常慢

TNM*_*TNM 5 multiple-columns tensorflow tensor

给定两个矩阵A(1000 x 100)和B(100 x 1000),而不是直接在张量流中计算其乘积,即 tf.dot(A,B),我想首先从A中选择10个cols(随机),从B中选择10行,然后使用tf.dot(A_s,B_s)

自然,第二次乘法应该更快,因为所需的乘法次数减少了10倍。

但是,实际上,在张量流中选择矩阵A的给定列来创建A_s似乎是一个效率极低的过程。

给定所需列的索引idx,我尝试了以下解决方案来创建A_s。这些解决方案根据其性能进行排名:

  1. . A_s = tf.transpose(tf.gather(tf.unstack(A, axis=1), idx))

tf.dot(A_s,B_s)tf.dot(A,B)创建A_s太昂贵要慢5倍。

    2。


     p_shape = K.shape(params)
     p_flat = K.reshape(params, [-1])
     i_flat = K.reshape(K.reshape(
        K.arange(0, p_shape[0]) * p_shape[1], [-1, 1]) + indices, [-1])
     indices = [i_flat]
     v = K.transpose(indices)
     updates = i_flat * 0 - 1
     shape = tf.to_int32([p_shape[0] * p_shape[1]])
     scatter = tf.scatter_nd(v, updates, shape) + 1
     out_temp = tf.dynamic_partition(p_flat,
                     partitions=scatter, num_partitions=2)[0]
     A_s = tf.reshape(out_temp, [p_shape[0], -1])

Run Code Online (Sandbox Code Playgroud)

导致产品慢6-7倍

    3。


      X,Y =  tf.meshgrid((tf.range(0, p_shape[0])),indices)
      idx = K.concatenate([K.expand_dims(
           K.reshape((X),[-1]),1), 
           K.expand_dims(K.reshape((Y),[-1]),1)],axis=1)
      A_s = tf.reshape(tf.gather_nd(params, idx), [p_shape[0], -1])

Run Code Online (Sandbox Code Playgroud)

慢10-12倍。

非常感谢我关于如何提高列选择过程效率的任何想法。

PS1:我在CPU上进行了所有实验。

PS2:矩阵A是一个占位符,不是变量。在某些实现中,由于可能无法推断其形状,因此可能会出现问题。