pytorch中当输入参数超过两个时如何使用forward()方法

Vik*_*ain 7 python neural-network deep-learning pytorch tensor

有人可以告诉我方法中多个参数背后的概念吗forward()?一般来说,方法的实现forward()有两个参数

  1. 自己
  2. 输入

如果前向方法的参数多于这些参数,PyTorch 如何使用前向方法。

让我们考虑这个代码库: https://github.com/bamps53/kaggle-autonomous-driven2019/blob/master/models/centernet.py 这里在线 236 位作者使用了带有两个参数的前向方法:

  1. 中心
  2. 返回嵌入

我找不到一篇文章可以回答我关于第 254( return_embeddings:) 行和第 257( if centers is not None:) 行将在什么条件下执行的查询。据我所知,该方法由 nn 模块内部调用。有人可以在这上面放一些灯吗?

anl*_*ses 9

转发功能由您设置。这意味着您可以根据需要添加更多参数。例如,您可以添加如下所示的输入

def forward(self, input1, input2, input3):
    x = self.layer1(input1)
    y = self.layer2(input2)
    z = self.layer3(input3)

    net = torch.cat((x,y,z),1)
     
    return net
Run Code Online (Sandbox Code Playgroud)

关键点是您必须在为网络提供数据时控制参数。图层只能输入一个参数。因此,您需要从输入中一一提取特征并将其与torch.cat((x,y),1)(1 代表维度)它们连接起来。