在tensorflow中是否有名为“index_select”的pytorch的等效函数

Sha*_*hen 1 tensorflow pytorch

我尝试将 pytorch 代码翻译为tensorflow。index_select所以我想知道在tensorflow中是否有一个与pytorch等效的函数

zih*_*hao 5

我还没有找到类似的api可以直接实现它,但是我们可以使用tf.slice它来实现它。


def tf_index_select(input_, dim, indices):
    """
    input_(tensor): input tensor
    dim(int): dimension
    indices(list): selected indices list
    """
    shape = input_.get_shape().as_list()
    if dim == -1:
        dim = len(shape)-1
    shape[dim] = 1
    
    tmp = []
    for idx in indices:
        begin = [0]*len(shape)
        begin[dim] = idx
        tmp.append(tf.slice(input_, begin, shape))
    res = tf.concat(tmp, axis=dim)
    
    return res
Run Code Online (Sandbox Code Playgroud)

这是一个显示等效性的示例。


import tensorflow as tf
import torch
import numpy as np

a = np.arange(2*3*4).reshape(2,3,4)
dim = 1
indices = [0,2]
# array([[[ 0,  1,  2,  3],
#         [ 4,  5,  6,  7],
#         [ 8,  9, 10, 11]],

#        [[12, 13, 14, 15],
#         [16, 17, 18, 19],
#         [20, 21, 22, 23]]])

# pytorch
res = torch.tensor(a).index_select(dim, torch.tensor(indices))
# tensor([[[ 0,  1,  2,  3],
#          [ 8,  9, 10, 11]],

#         [[12, 13, 14, 15],
#          [20, 21, 22, 23]]])

# tensorflow
res = tf_index_select(tf.constant(a), dim, indices)
# tensor([[[ 0,  1,  2,  3],
#          [ 8,  9, 10, 11]],

#         [[12, 13, 14, 15],
#          [20, 21, 22, 23]]])
Run Code Online (Sandbox Code Playgroud)