如何在 Pytorch 中相互交换 3 个维度?

Aud*_*ey 8 python pytorch

我有一个a= torch.randn(28, 28, 8),我想将尺寸交换(0, 1, 2)(2, 0, 1)。我尝试过b = a.transpose(2, 0, 1),但收到此错误:

TypeError: transpose() received an invalid combination of arguments - got 
(int, int, int), but expected one of:
 * (name dim0, name dim1)
 * (int dim0, int dim1)
Run Code Online (Sandbox Code Playgroud)

有什么办法可以一次性全部换掉吗?

Vig*_*n C 8

你可以使用Pytorch的permute()功能一次性全部交换,

>>>a = torch.randn(28, 28, 8)
>>>b = a.permute(2, 0, 1)
>>>b.shape
torch.Size([8, 28, 28])
Run Code Online (Sandbox Code Playgroud)