PyTorch 张量的对角线为零?

gus*_*gus 3 python diagonal pytorch tensor

有没有一种简单的方法可以将 PyTorch 张量的对角线归零?

例如我有:

tensor([[2.7183, 0.4005, 2.7183, 0.5236],
        [0.4005, 2.7183, 0.4004, 1.3469],
        [2.7183, 0.4004, 2.7183, 0.5239],
        [0.5236, 1.3469, 0.5239, 2.7183]])
Run Code Online (Sandbox Code Playgroud)

我想得到:

tensor([[0.0000, 0.4005, 2.7183, 0.5236],
        [0.4005, 0.0000, 0.4004, 1.3469],
        [2.7183, 0.4004, 0.0000, 0.5239],
        [0.5236, 1.3469, 0.5239, 0.0000]])
Run Code Online (Sandbox Code Playgroud)

tri*_*ror 5

我相信最简单的方法是使用torch.diagonal

z = torch.randn(4,4)
torch.diagonal(z, 0).zero_()
print(z)
>>> tensor([[ 0.0000, -0.6211,  0.1120,  0.8362],
            [-0.1043,  0.0000,  0.1770,  0.4197],
            [ 0.7211,  0.1138,  0.0000, -0.7486], 
            [-0.5434, -0.8265, -0.2436,  0.0000]])
Run Code Online (Sandbox Code Playgroud)

这样,代码就完全明确了,并且您可以将性能委托给 pytorch 的内置函数。


uke*_*emi 5

您可以简单地使用:

x.fill_diagonal_(0)
Run Code Online (Sandbox Code Playgroud)