gas*_*oon 8 python deep-learning pytorch tensor
让我们调用我正在寻找的函数“ magic_combine”,它可以组合我给它的张量的连续维度。更具体地说,我希望它做以下事情:
a = torch.zeros(1, 2, 3, 4, 5, 6)
b = a.magic_combine(2, 5) # combine dimension 2, 3, 4
print(b.size()) # should be (1, 2, 60, 6)
Run Code Online (Sandbox Code Playgroud)
我知道torch.view()可以做类似的事情。但我只是想知道是否有更优雅的方式来实现目标?
小智 16
flatten该start_dim参数有一个变体end_dim。您可以按照与您的相同的方式调用它magic_combine(除了end_dim包含在内)。
a = torch.zeros(1, 2, 3, 4, 5, 6)
b = a.flatten(2, 4) # combine dimension 2, 3, 4
print(b.size()) # should be (1, 2, 60, 6)
Run Code Online (Sandbox Code Playgroud)
https://pytorch.org/docs/stable/ generated/torch.flatten.html
还有一个相应的unflatten,您可以在其中指定要展开的尺寸和要展开的形状。
a = torch.zeros(1, 2, 3, 4, 5, 6)
b = a.view(*a.shape[:2], -1, *a.shape[5:])
Run Code Online (Sandbox Code Playgroud)
在我看来,比当前接受的答案简单一点,并且没有通过list构造函数(3 次)。
我不确定您对“更优雅的方式”的想法是什么,但Tensor.view()其优点是不为视图重新分配数据(原始张量和视图共享相同的数据),使得此操作相当轻量级。
正如@UmangGupta 所提到的,包装这个函数来实现你想要的东西是相当简单的,例如:
import torch
def magic_combine(x, dim_begin, dim_end):
combined_shape = list(x.shape[:dim_begin]) + [-1] + list(x.shape[dim_end:])
return x.view(combined_shape)
a = torch.zeros(1, 2, 3, 4, 5, 6)
b = magic_combine(a, 2, 5) # combine dimension 2, 3, 4
print(b.size())
# torch.Size([1, 2, 60, 6])
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
7982 次 |
| 最近记录: |