小编chi*_*wi9的帖子

如何在 TensorFlow 中对批次进行切片并对每个切片应用操作

我是 TensorFlow 的初学者,我正在尝试实现一个将批处理作为输入的函数。它必须将这个批次分成几个,对它们应用一些操作,然后将它们连接起来以构建一个新的张量以返回。通过我的阅读,我发现有一些已实现的函数,如 input_slice_producer 和 batch_join,但我没有开始使用它们。我附上了我在下面找到的解决方案,但它有点慢,不正确并且无法检测当前批次的大小。有没有人知道这样做的更好方法?

def model(x):

    W_1 = tf.Variable(tf.random_normal([6,1]),name="W_1")
    x_size = x.get_shape().as_list()[0]
    # x is a batch of bigger input of shape [None,6], so I couldn't 
    # get the proper size of the batch when feeding it 
    if x_size == None:
        x_size= batch_size
    #intialize the y_res
    dummy_x = tf.slice(x,[0,0],[1,6])
    result = tf.reduce_sum(tf.mul(dummy_x,W_1))
    y_res = tf.zeros([1], tf.float32)
    y_res = result
    #go throw all slices and concatenate them to get result
    for i in range(1,x_size): 
        dummy_x = tf.slice(x,[i,0],[1,6])
        result = …
Run Code Online (Sandbox Code Playgroud)

concatenation slice tensorflow

4
推荐指数
1
解决办法
4874
查看次数

标签 统计

concatenation ×1

slice ×1

tensorflow ×1