Hah*_*pro 3 python indexing pytorch tensor
我使用索引张量在张量中选择了元素。下面的代码我使用索引 0、3、2、1 列表来选择 11、15、2、5
>>> import torch
>>> a = torch.Tensor([5,2,11, 15])
>>> torch.randperm(4)
0
3
2
1
[torch.LongTensor of size 4]
>>> i = torch.randperm(4)
>>> a[i]
11
15
2
5
[torch.FloatTensor of size 4]
Run Code Online (Sandbox Code Playgroud)
我现在有
>>> b = torch.Tensor([[5, 2, 11, 15],[5, 2, 11, 15], [5, 2, 11, 15]])
>>> b
5 2 11 15
5 2 11 15
5 2 11 15
[torch.FloatTensor of size 3x4]
Run Code Online (Sandbox Code Playgroud)
现在,我想使用索引来选择列 0, 3, 2, 1。换句话说,我想要一个像这样的张量
>>> b
11 15 2 5
11 15 2 5
11 15 2 5
[torch.FloatTensor of size 3x4]
Run Code Online (Sandbox Code Playgroud)
对于这个版本,没有一个简单的方法可以做到这一点。尽管 pytorch 承诺张量操作与 numpy 完全相同,但仍然缺乏一些功能。这是其中之一。
通常,如果您使用 numpy 数组,您将能够相对轻松地完成此操作。就像这样。
>>> i = [2, 1, 0, 3]
>>> a = np.array([[5, 2, 11, 15],[5, 2, 11, 15], [5, 2, 11, 15]])
>>> a[:, i]
array([[11, 2, 5, 15],
[11, 2, 5, 15],
[11, 2, 5, 15]])
Run Code Online (Sandbox Code Playgroud)
但是对于张量来说同样的事情会给你一个错误:
>>> i = torch.LongTensor([2, 1, 0, 3])
>>> a = torch.Tensor([[5, 2, 11, 15],[5, 2, 11, 15], [5, 2, 11, 15]])
>>> a[:,i]
Run Code Online (Sandbox Code Playgroud)
错误:
类型错误:使用 torch.LongTensor 类型的对象索引张量。唯一支持的类型是整数、切片、numpy 标量和 torch.LongTensor 或 torch.ByteTensor 作为唯一参数。
TypeError 告诉您的是,如果您打算使用 LongTensor 或 ByteTensor 进行索引,那么唯一有效的语法是a[<LongTensor>]or a[<ByteTensor>]。除此之外的任何事情都不会起作用。
由于此限制,您有两种选择:
选项 1:转换为 numpy,排列,然后返回 Tensor
>>> i = [2, 1, 0, 3]
>>> a = torch.Tensor([[5, 2, 11, 15],[5, 2, 11, 15], [5, 2, 11, 15]])
>>> np_a = a.numpy()
>>> np_a = np_a[:,i]
>>> a = torch.from_numpy(np_a)
>>> a
11 2 5 15
11 2 5 15
11 2 5 15
[torch.FloatTensor of size 3x4]
Run Code Online (Sandbox Code Playgroud)
选项 2:将要排列的暗淡移动到 0,然后执行此操作
您将要移动的暗淡(在您的情况下暗淡= 1)为0,执行排列,然后将其移回。它有点hacky,但它完成了工作。
def hacky_permute(a, i, dim):
a = torch.transpose(a, 0, dim)
a = a[i]
a = torch.transpose(a, 0, dim)
return a
Run Code Online (Sandbox Code Playgroud)
并像这样使用它:
>>> i = torch.LongTensor([2, 1, 0, 3])
>>> a = torch.Tensor([[5, 2, 11, 15],[5, 2, 11, 15], [5, 2, 11, 15]])
>>> a = hacky_permute(a, i, dim=1)
>>> a
11 2 5 15
11 2 5 15
11 2 5 15
[torch.FloatTensor of size 3x4]
Run Code Online (Sandbox Code Playgroud)
使用张量的直接索引现在可以在此版本中使用。IE。
>>> i = torch.LongTensor([2, 1, 0, 3])
>>> a = torch.Tensor([[5, 2, 11, 15],[5, 2, 11, 15], [5, 2, 11, 15]])
>>> a[:,i]
11 2 5 15
11 2 5 15
11 2 5 15
[torch.FloatTensor of size 3x4]
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
11349 次 |
| 最近记录: |