对于张量:
x = torch.tensor([
[
[[0.4495, 0.2356],
[0.4069, 0.2361],
[0.4224, 0.2362]],
[[0.4357, 0.6762],
[0.4370, 0.6779],
[0.4406, 0.6663]]
],
[
[[0.5796, 0.4047],
[0.5655, 0.4080],
[0.5431, 0.4035]],
[[0.5338, 0.6255],
[0.5335, 0.6266],
[0.5204, 0.6396]]
]
])
Run Code Online (Sandbox Code Playgroud)
首先想将其分成 2 个 (x.shape[0]) 张量,然后将它们连接起来。在这里,只要获得正确的输出,我实际上并不需要将其拆分,但在视觉上将其拆分然后将它们连接在一起对我来说更有意义。
例如:
# the shape of the splits are always the same
split1 = torch.tensor([
[[0.4495, 0.2356],
[0.4069, 0.2361],
[0.4224, 0.2362]],
[[0.4357, 0.6762],
[0.4370, 0.6779],
[0.4406, 0.6663]]
])
split2 = torch.tensor([
[[0.5796, 0.4047],
[0.5655, 0.4080],
[0.5431, 0.4035]],
[[0.5338, 0.6255],
[0.5335, 0.6266],
[0.5204, 0.6396]] …Run Code Online (Sandbox Code Playgroud)