laz*_*aza 9 python diagonal pytorch
我一直在到处寻找与 PyTorch 等效的东西,但我找不到任何东西。
L_1 = np.tril(np.random.normal(scale=1., size=(D, D)), k=0)
L_1[np.diag_indices_from(L_1)] = np.exp(np.diagonal(L_1))
Run Code Online (Sandbox Code Playgroud)
我想使用 Pytorch 无法以如此优雅的方式替换对角线元素。
我认为目前还没有实现这样的功能。mask但是,您可以使用以下方法实现相同的功能。
# Assuming v to be the vector and a be the tensor whose diagonal is to be replaced
mask = torch.diag(torch.ones_like(v))
out = mask*torch.diag(v) + (1. - mask)*a
Run Code Online (Sandbox Code Playgroud)
所以,你的实现将类似于
L_1 = torch.tril(torch.randn((D, D)))
v = torch.exp(torch.diag(L_1))
mask = torch.diag(torch.ones_like(v))
L_1 = mask*torch.diag(v) + (1. - mask)*L_1
Run Code Online (Sandbox Code Playgroud)
不像 numpy 那样优雅,但也不算太糟糕。
| 归档时间: |
|
| 查看次数: |
12304 次 |
| 最近记录: |