我有一个X
像的张量[0.1, 0.5, -1.0, 0, 1.2, 0]
,我想实现一个名为 的函数filter_positive()
,它可以将正数据过滤成一个新的张量并返回原始张量的索引。例如:
new_tensor, index = filter_positive(X)
new_tensor = [0.1, 0.5, 1.2]
index = [0, 1, 4]
Run Code Online (Sandbox Code Playgroud)
如何在 pytorch 中最有效地实现此功能?
nem*_*emo 18
看看torch.nonzero
哪个大致相当于np.where
. 它将二进制掩码转换为索引:
>>> X = torch.tensor([0.1, 0.5, -1.0, 0, 1.2, 0])
>>> mask = X >= 0
>>> mask
tensor([1, 1, 0, 1, 1, 1], dtype=torch.uint8)
>>> indices = torch.nonzero(mask)
>>> indices
tensor([[0],
[1],
[3],
[4],
[5]])
>>> X[indices]
tensor([[0.1000],
[0.5000],
[0.0000],
[1.2000],
[0.0000]])
Run Code Online (Sandbox Code Playgroud)
一个解决方案是这样写:
mask = X >= 0
new_tensor = X[mask]
indices = torch.nonzero(mask)
Run Code Online (Sandbox Code Playgroud)
归档时间: |
|
查看次数: |
7846 次 |
最近记录: |