torch.stack()和torch.cat()函数有什么区别?

Gul*_*zar 16 python machine-learning pytorch

OpenAI的REINFORCE和演员批评强化学习的例子有以下代码:

加强:

policy_loss = torch.cat(policy_loss).sum()
Run Code Online (Sandbox Code Playgroud)

演员评论家:

loss = torch.stack(policy_losses).sum() + torch.stack(value_losses).sum()
Run Code Online (Sandbox Code Playgroud)

一个是使用torch.cat,另一个是使用torch.stack.

就我的理解而言,该文件并未对它们作出任何明确的区分.

我很高兴知道这些功能之间的差异.

Jat*_*aki 33

stack

沿新维度连接张量序列.

cat

在给定维度中连接给定的seq张量序列.

因此,如果A并且B具有形状(3,4),torch.cat([A, B], dim=0)则将具有形状(6,4)并且torch.stack([A, B], dim=0)将具有形状(2,3,4).

  • 因此, torch.stack([A,B],dim = 0) 相当于 torch.cat([A.unsqueeze(0),b.unsqueeze(0)],dim = 0) 。因此,如果您发现自己在组合张量之前执行了许多 unsqueeze() 操作,则可以使用 stack() 简化代码。 (7认同)
  • 作为补充,在问题中的 OpenAI 示例中,“torch.stack”和“torch.cat”可以在任一代码行中互换使用,因为“torch.stack(tensors).sum() == torch.cat(tensors ).sum()`。 (3认同)

uke*_*emi 9

t1 = torch.tensor([[1, 2],
                   [3, 4]])

t2 = torch.tensor([[5, 6],
                   [7, 8]])
Run Code Online (Sandbox Code Playgroud)
torch.stack torch.cat
沿新维度“堆叠”一系列张量:

在此处输入图片说明



'Con cat生成'沿现有维度的一系列张量:

在此处输入图片说明

这些函数类似于numpy.stacknumpy.concatenate