如何指定pytorch nn.Linear的输入维度?

yi *_* li 1 python deep-learning pytorch

例如,我定义了一个模型,如下所示:

class Net(nn.module):
   def __init__():
      self.conv11 = nn.Conv2d(in_channel,out1_channel,3)
      self.conv12 = nn.Conv2d(...)
      self.conv13 = nn.Conv2d(...)
      self.conv14 = nn.Conv2d(...)
      ...
      #Here is the point
      flat = nn.Flatten()
      #I don't want to compute the size of data after flatten, but I need a linear layer.

      fc_out = nn.Linear(???,out_dim)
Run Code Online (Sandbox Code Playgroud)

问题是线性层,我不想计算线性层输入的大小,但定义模型需要指定它。我怎么解决这个问题?

uke*_*emi 7

如果您不想显式计算输入大小,您可以使用一个LazyLinear模块(而不是 a Linear)来推断它:

fc_out = nn.LazyLinear(out_dim)
Run Code Online (Sandbox Code Playgroud)