如何从 Pytorch 张量中删除每列填充零的列?

chi*_*913 4 python pytorch tensor torchtext

我有一个A如下所示的 pytorch 张量:

A = 
tensor([[  4,   3,   3,  ...,   0,   0,   0],
        [ 13,   4,  13,  ...,   0,   0,   0],
        [707, 707,   4,  ...,   0,   0,   0],
        ...,
        [  7,   7,   7,  ...,   0,   0,   0],
        [  0,   0,   0,  ...,   0,   0,   0],
        [195, 195, 195,  ...,   0,   0,   0]], dtype=torch.int32)
Run Code Online (Sandbox Code Playgroud)

我想:

  • 识别所有条目都等于 0 的所有列
  • 仅删除所有条目都等于 0 的那些列

我可以想象这样做:

zero_list = []
for j in range(A.size()[1]):
    if torch.sum(A[:,j]) == 0:
         zero_list = zero_list.append(j)
Run Code Online (Sandbox Code Playgroud)

识别其元素只有 0 的列,但我不确定如何从原始张量中删除这些填充有 0 的列。

如何根据索引号从 pytorch 张量中删除零列?

谢谢你,

Ayd*_*ydo 7

找出所有条目都等于 0 的所有列

non_empty_mask = A.abs().sum(dim=0).bool()

这对每列的绝对值求和,然后将结果转换为布尔值,即False总和是否为零,True否则。

仅删除所有条目都等于 0 的那些列

A[:,non_empty_mask]

这只是将掩码应用于原始张量,即它将行保留在non_empty_maskis 处True