我不明白当数据是 3D 时 BatchNorm1d 如何工作(批量大小、H、W)。
例子
如果我随后包含批量归一化层,则需要 num_features=50:
我不明白为什么不是 20:
示例1)
class Net(nn.Module):
def __init__(self):
super(Net,self).__init__()
self.bn11 = nn.BatchNorm1d(50)
self.fc11 = nn.Linear(70,20)
def forward(self, inputs):
out = self.fc11(inputs)
out = torch.relu(self.bn11(out))
return out
model = Net()
inputs = torch.Tensor(2,50,70)
outputs = model(inputs)
Run Code Online (Sandbox Code Playgroud)
示例2)
class Net(nn.Module):
def __init__(self):
super(Net,self).__init__()
self.bn11 = nn.BatchNorm1d(20)
self.fc11 = nn.Linear(70,20)
def forward(self, inputs):
out = self.fc11(inputs)
out = torch.relu(self.bn11(out))
return out
model = Net()
inputs = torch.Tensor(2,50,70)
outputs = model(inputs)
Run Code Online (Sandbox Code Playgroud)
2D 示例:
我认为 BN 层中的 20 是因为线性层输出了 20 个节点,并且每个节点都需要一个运行的平均值/标准差来输入值。
为什么在 3D 情况下,如果线性层有 20 个输出节点,BN 层没有 20 个特征?
torch.nn.Linear人们可以在文档中找到答案。
它采用input形状(N, *, I)并返回(N, *, O),其中I代表输入维度和O输出维度,并且*是之间的任何维度。
如果你torch.Tensor(2,50,70)传入nn.Linear(70,20),你会得到 shape 的输出(2, 50, 20),当你使用BatchNorm1d它时,它会计算第一个非批量维度的运行平均值,所以它将是 50。这就是你的错误背后的原因。
| 归档时间: |
|
| 查看次数: |
12347 次 |
| 最近记录: |