Gab*_*rti 3 python neural-network pytorch
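I'm new to PyTorch and I'd like to understand how to set the initial weights of the first hidden layer of my network. To explain a bit better: my network is a very simple one-layer MLP with 784 input values and 10 output values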
import torch.nn as nn
import torch.nn.functional as F

class Classifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 10)
        # Dropout module with 0.2 drop probability
        self.dropout = nn.Dropout(p=0.2)

    def forward(self, x):
        # make sure input tensor is flattened
        # x = x.view(x.shape[0], -1)
        # Now with dropout
        x = self.dropout(F.relu(self.fc1(x)))
        # output so no dropout here
        x = F.log_softmax(self.fc2(x), dim=1)
        return x
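Now, I have a numpy matrix of shape (128, 784) that contains the weight values I want in fc1. How can I initialize the weights of the first layer with the values contained in that matrix?
Searching other answers online, I found that I have to define an initialization function for the weights, e.g.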
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv2d') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
but I can't understand the code.
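For context, a function like the one above is normally passed to `Module.apply`, which calls it once on every submodule of the network; the function checks the layer type and overwrites that layer's weights in place. Below is a minimal sketch adapted to `nn.Linear` layers. The `nn.Sequential` model and the 0.02 standard deviation are only illustrative assumptions, not part of the network in the question.

```python
import torch.nn as nn

def weights_init(m):
    # model.apply() calls this once for every submodule (and for the model itself).
    # Only Linear layers are touched; everything else is left unchanged.
    if isinstance(m, nn.Linear):
        nn.init.normal_(m.weight, mean=0.0, std=0.02)
        nn.init.zeros_(m.bias)

# Stand-in model with the same layer sizes as in the question (assumption).
model = nn.Sequential(nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10))
model.apply(weights_init)
```

This kind of initializer draws random values for whole classes of layers, so it is not needed when the goal is to load one specific matrix into one specific layer.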
You can simply use torch.nn.Parameter() to assign custom weights to a layer of your network.
In your case:
model.fc1.weight = torch.nn.Parameter(custom_weight)
torch.nn.Parameter: a kind of Tensor that is to be considered a module parameter.
For example:
# Classifier model
model = Classifier()
# your custom weight, here random values for illustration
custom_weight = torch.rand(model.fc1.weight.shape)
custom_weight.shape
torch.Size([128, 784])
# before assigning the custom weight
print(model.fc1.weight)
Parameter containing:
tensor([[ 1.6920e-02, 4.6515e-03, -1.0214e-02, ..., -7.6517e-03,
2.3892e-02, -8.8965e-03],
...,
[-2.3137e-02, 5.8483e-03, 4.4392e-03, ..., -1.6159e-02,
7.9369e-03, -7.7326e-03]])
# assign custom weight to first layer
model.fc1.weight = torch.nn.Parameter(custom_weight)
# after assigning the custom weight
model.fc1.weight
Parameter containing:
tensor([[ 0.1724, 0.7513, 0.8454, ..., 0.8780, 0.5330, 0.5847],
[ 0.8500, 0.7687, 0.3371, ..., 0.7464, 0.1503, 0.7720],
[ 0.8514, 0.6530, 0.6261, ..., 0.7867, 0.9312, 0.3890],
...,
[ 0.5426, 0.7655, 0.1191, ..., 0.4343, 0.2500, 0.6207],
[ 0.2310, 0.4260, 0.4138, ..., 0.1168, 0.5946, 0.2505],
[ 0.4220, 0.5500, 0.6282, ..., 0.5921, 0.7953, 0.9997]])
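Since the question starts from a numpy matrix rather than a tensor, here is a rough sketch of that case. `np_weights` is a hypothetical stand-in for your (128, 784) matrix, and a bare `nn.Linear` stands in for `model.fc1`. NumPy arrays are float64 by default, so the sketch converts to float32 to match the dtype that `nn.Linear` uses by default.

```python
import numpy as np
import torch
import torch.nn as nn

# Hypothetical stand-in for the (128, 784) numpy matrix from the question.
np_weights = np.random.rand(128, 784)  # float64 by default

# Stand-in for model.fc1; its weight has shape (out_features, in_features) = (128, 784).
fc1 = nn.Linear(784, 128)

# Convert to a float32 tensor so the dtype matches the layer's bias and inputs.
custom_weight = torch.from_numpy(np_weights).float()

# Option 1: replace the parameter object, as in the answer above.
fc1.weight = nn.Parameter(custom_weight)

# Option 2: copy the values into the existing parameter in place.
# copy_ converts float64 to float32 itself; no_grad keeps autograd out of it.
with torch.no_grad():
    fc1.weight.copy_(torch.from_numpy(np_weights))
```

Either option should leave `fc1.weight` holding your values with `requires_grad=True`. If an optimizer was already created before the assignment, Option 2 is probably the safer choice, because Option 1 creates a new parameter object that the existing optimizer does not track.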