将所有非零值替换为零,将所有零值替换为特定值

Was*_*mad 6 pytorch

我有一个3d张量,包含一些零和非零值.我想用零和零值替换所有非零值的特定值.我怎样才能做到这一点?

ent*_*phy 6

几乎就是你如何使用numpy来做到这一点,如下所示:

tensor[tensor!=0] = 0
Run Code Online (Sandbox Code Playgroud)

为了替换零和非零,您可以将它们链接在一起.请务必使用张量的副本,因为它们会被修改:

def custom_replace(tensor, on_zero, on_non_zero):
    # we create a copy of the original tensor, 
    # because of the way we are replacing them.
    res = tensor.clone()
    res[tensor==0] = on_zero
    res[tensor!=0] = on_non_zero
    return res
Run Code Online (Sandbox Code Playgroud)

并像这样使用它:

>>>z 
(0 ,.,.) = 
  0  1
  1  3

(1 ,.,.) = 
  0  1
  1  0
[torch.LongTensor of size 2x2x2]

>>>out = custom_replace(z, on_zero=5, on_non_zero=0)
>>>out
(0 ,.,.) = 
  5  0
  0  0

(1 ,.,.) = 
  5  0
  0  5
[torch.LongTensor of size 2x2x2]
Run Code Online (Sandbox Code Playgroud)

  • 小心这个操作是不可区分的!因此,不会发生反向传播. (3认同)

Eli*_*fra 6

使用

torch.where(<your_tensor> != 0, <tensor with zeroz>, <tensor with the value>)
Run Code Online (Sandbox Code Playgroud)

例子:

>>> x = torch.randn(3, 2)
>>> y = torch.ones(3, 2)
>>> x
tensor([[-0.4620,  0.3139],
         [ 0.3898, -0.7197],
         [ 0.0478, -0.1657]])
>>> torch.where(x > 0, x, y)
Tensor([[ 1.0000,  0.3139],
        [ 0.3898,  1.0000],
        [ 0.0478,  1.0000]])
Run Code Online (Sandbox Code Playgroud)

查看更多信息:https ://pytorch.org/docs/stable/ generated/torch.where.html