Mar*_*KZZ 8 python-3.x pytorch
给定一个张量:
A = torch.tensor([2., 3., 4., 5., 6., 7.])
Run Code Online (Sandbox Code Playgroud)
然后,给每个元素A一个 id:
id = torch.arange(A.shape[0], dtype = torch.int) # tensor([0,1,2,3,4,5])
Run Code Online (Sandbox Code Playgroud)
换句话说,2.in的idA为0,3.in的idA为1:
2. -> 0
3. -> 1
4. -> 2
5. -> 3
6. -> 4
7. -> 5
Run Code Online (Sandbox Code Playgroud)
然后,我有一个新的张量:
B = torch.tensor([3., 6., 6., 5., 4., 4., 4.])
Run Code Online (Sandbox Code Playgroud)
在pytorch中,Pytorch中有什么方法可以将B中的每个元素映射到id吗?换句话说,我想获取tensor([1, 4, 4, 3, 2, 2, 2]),其中每个元素都是 中元素的 id B。
我认为 PyTorch 中没有这样的函数来映射张量。
B通过比较每个值 from和 value from 来解决这个问题似乎很不合理B。
以下是解决此问题的两种可能的解决方案。
你可以使用字典。并不是纯粹的 PyTorch解决方案,但很可能是最快、最安全的方法......
只需创建一个字典将每个元素映射到一个 id,然后用它来映射B:
>>> map = {x.item(): i for i, x in enumerate(A)}
>>> torch.tensor([map[x.item()] for x in B])
tensor([1, 4, 4, 3, 2, 2, 2])
Run Code Online (Sandbox Code Playgroud)
仅使用 s 的替代方案torch.Tensor。这将要求您想要映射的值(其内容)A为整数,因为它们将用于索引张量。
将 的内容编码A为 one-hot 编码:
>>> A_enc = torch.zeros((int(A.max())+1,)*2)
>>> A_enc[A, torch.arange(A.shape[0])] = 1
>>> A_enc
tensor([[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[1., 0., 0., 0., 0., 0., 0., 0.],
[0., 1., 0., 0., 0., 0., 0., 0.],
[0., 0., 1., 0., 0., 0., 0., 0.],
[0., 0., 0., 1., 0., 0., 0., 0.],
[0., 0., 0., 0., 1., 0., 0., 0.],
[0., 0., 0., 0., 0., 1., 0., 0.]])
Run Code Online (Sandbox Code Playgroud)
我们将使用A_enc作为映射整数的基础:
>>> v = torch.argmax(A_enc, dim=0)
tensor([0, 0, 0, 1, 2, 3, 4, 5])
Run Code Online (Sandbox Code Playgroud)
现在,给定一个整数,例如x=3,我们可以将其编码为 one-hot-encoding:x_enc = [0, 0, 0, 1, 0, 0, 0, 0]。然后,使用v它来绘制地图。使用简单的点积,您可以获得 的映射x_enc: 这里<v/x_enc>给出了1所需的结果(mapped- 的第一个元素B)。但我们将计算和x_enc之间的矩阵乘法,而不是给出。首先编码,然后计算矩阵乘法x :vBBvB_enc
>>> B_enc = torch.zeros(A_enc.shape[0], B.shape[0])
>>> B_enc[B, torch.arange(B.shape[0])] = 1
>>> B_enc
tensor([[0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0.],
[1., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 1., 1., 1.],
[0., 0., 0., 1., 0., 0., 0.],
[0., 1., 1., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0.]])
>>> v@B_enc.long()
tensor([1, 4, 4, 3, 2, 2, 2])
Run Code Online (Sandbox Code Playgroud)
注意- 您必须使用Long类型定义张量。
| 归档时间: |
|
| 查看次数: |
7863 次 |
| 最近记录: |