Ros*_*oss 11 python casting boolean pytorch tensor
我想将整数张量转换为布尔值张量。
具体来说,我希望能够拥有一个转换tensor([0,10,0,16])
为tensor([0,1,0,1])
这在 Tensorflow 中只需使用tf.cast(x,tf.bool)
.
我希望强制转换将所有大于 0 的整数更改为 1,并将所有等于 0 的整数更改为 0。这!!
在大多数语言中都是等价的。
由于 pytorch 似乎没有专用的布尔类型可以转换,这里最好的方法是什么?
编辑:我正在寻找一个矢量化的解决方案,而不是遍历每个元素。
kma*_*o23 16
您正在寻找的是为给定的整数张量生成一个布尔掩码。为此,您可以使用简单的比较运算符 ( >
) 或 using简单地检查条件:“张量中的值是否大于 0” torch.gt()
,这将给我们所需的结果。
# input tensor
In [76]: t
Out[76]: tensor([ 0, 10, 0, 16])
# generate the needed boolean mask
In [78]: t > 0
Out[78]: tensor([0, 1, 0, 1], dtype=torch.uint8)
Run Code Online (Sandbox Code Playgroud)
# sanity check
In [93]: mask = t > 0
In [94]: mask.type()
Out[94]: 'torch.ByteTensor'
Run Code Online (Sandbox Code Playgroud)
注意:在 PyTorch 1.4+ 版本中,上述操作会返回'torch.BoolTensor'
In [9]: t > 0
Out[9]: tensor([False, True, False, True])
# alternatively, use `torch.gt()` API
In [11]: torch.gt(t, 0)
Out[11]: tensor([False, True, False, True])
Run Code Online (Sandbox Code Playgroud)
如果您确实想要单个位(0
s 或1
s),请使用:
In [14]: (t > 0).type(torch.uint8)
Out[14]: tensor([0, 1, 0, 1], dtype=torch.uint8)
# alternatively, use `torch.gt()` API
In [15]: torch.gt(t, 0).int()
Out[15]: tensor([0, 1, 0, 1], dtype=torch.int32)
Run Code Online (Sandbox Code Playgroud)
此功能请求问题中已讨论了此更改的原因:issues/4764 - Introduce torch.BoolTensor ...
TL;DR : 简单的一个衬垫
t.bool().int()
Run Code Online (Sandbox Code Playgroud)
PyTorch 的to(dtype)
方法具有名为 aliases 的便捷数据类型。您可以简单地致电bool
:
>>> t.bool()
tensor([False, True, False, True])
Run Code Online (Sandbox Code Playgroud)
>>> t.bool().int()
tensor([0, 1, 0, 1], dtype=torch.int32)
Run Code Online (Sandbox Code Playgroud)
归档时间: |
|
查看次数: |
23758 次 |
最近记录: |