过滤pytorch张量中的数据

dod*_*ong 15 python pytorch

我有一个X像的张量[0.1, 0.5, -1.0, 0, 1.2, 0],我想实现一个名为 的函数filter_positive(),它可以将正数据过滤成一个新的张量并返回原始张量的索引。例如:

new_tensor, index = filter_positive(X)

new_tensor = [0.1, 0.5, 1.2]
index = [0, 1, 4]
Run Code Online (Sandbox Code Playgroud)

如何在 pytorch 中最有效地实现此功能?

nem*_*emo 18

看看torch.nonzero哪个大致相当于np.where. 它将二进制掩码转换为索引:

>>> X = torch.tensor([0.1, 0.5, -1.0, 0, 1.2, 0])
>>> mask = X >= 0
>>> mask
tensor([1, 1, 0, 1, 1, 1], dtype=torch.uint8)

>>> indices = torch.nonzero(mask)
>>> indices
tensor([[0],
        [1],
        [3],
        [4],
        [5]])

>>> X[indices]
tensor([[0.1000],
        [0.5000],
        [0.0000],
        [1.2000],
        [0.0000]])
Run Code Online (Sandbox Code Playgroud)

一个解决方案是这样写:

mask = X >= 0
new_tensor = X[mask]
indices = torch.nonzero(mask)
Run Code Online (Sandbox Code Playgroud)


Dio*_*ago 5

如果不需要索引,你可以这样做:

X = X[X > 0]
Run Code Online (Sandbox Code Playgroud)