使用 PyTorch 张量将对角线屏蔽为特定值

Nic*_*aiF 5 python diagonal torch pytorch

如何用 torch 中的值填充对角线?在 numpy 中你可以这样做:

a = np.zeros((3, 3), int)
np.fill_diagonal(a, 5)

array([[5, 0, 0],
       [0, 5, 0],
       [0, 0, 5]])
Run Code Online (Sandbox Code Playgroud)

我知道torch.diag()返回对角线,但如何使用它作为掩码来分配新值超出了我的范围。我无法在这里或 PyTorch 文档中找到答案。

uke*_*emi 4

您可以在 PyTorch 中使用以下命令执行此操作fill_diagonal_

>>> a = torch.zeros(3, 3)
>>> a.fill_diagonal_(5)
tensor([[5, 0, 0],
        [0, 5, 0],
        [0, 0, 5]])
Run Code Online (Sandbox Code Playgroud)