PyTorch:按自定义顺序有效地交错两个张量

HMK*_*HMK 5 python torch pytorch tensor

我想z从两个张量创建一个新的张量,例如x和 ,y其尺寸分别为[N_samples, S, N_feats][N_samples, T, N_feats]。目的是通过以特定顺序混合第二个维度的元素来组合第二个维度上的两个张量,该顺序存储在order带有维度的变量中[N_samples, U]

每个样本的排序都不同,基本上是从哪个张量中提取哪个索引。order[0]对于给定的样本- ,它看起来像这样[x_0, x_1, y_0, x_2, y_1, ... ],其中字母表示张量,数字表示第二个暗淡的索引。所以z[0]会是

z[0] = [x[0, 0, :], x[0, 1, :], y[0, 0, :], x[0, 2, :], y[0, 1, :] ... ]

我将如何实现这一目标?我写了一些东西试图torch.gather做到这一点。

x = torch.rand((2, 4, 5))
y = torch.rand((2, 3, 5))

# new ordering of second dim
# positive means take (n-1)th element from x
# negative means take (n-1)th element from y
order = [[1, 2, -1, 3, -2, 4, 3], 
         [1, -1, -2, 2, 3, 4, -3]]

# simple concat for gather
combined = torch.cat([x, y], dim=1)

# add a zero padding on top of combined tensor to ease gather
zero = torch.zeros_like(x)[:, 1:2] 
combined = torch.cat([zero, combined], dim=1)

def _create_index_for_gather(index, offset, n_feats):
    new_index = [abs(i) + offset if i < 0 else i for i in index]

    # need to repeat index for each dim for torch.gather
    new_index = [[x] * n_feats for x in new_index]
    return new_index

_, offset, n_feats = x.shape
index_for_gather = [_create_index_for_gather(i, offset, n_feats) for i in order]

z = combined.gather(dim=1, index=torch.tensor(index_for_gather))
Run Code Online (Sandbox Code Playgroud)

有没有更有效的方法来做到这一点?