相关疑难解决方法(0)

了解torch.nn.Parameter

我是pytorch的新手,我很难理解它是如何torch.nn.Parameter()工作的.

我已经浏览了https://pytorch.org/docs/stable/nn.html中的文档,但可能会对此有所了解.

有人可以帮忙吗?

我正在处理的代码片段:

def __init__(self, weight):
    super(Net, self).__init__()
    # initializes the weights of the convolutional layer to be the weights of the 4 defined filters
    k_height, k_width = weight.shape[2:]
    # assumes there are 4 grayscale filters
    self.conv = nn.Conv2d(1, 4, kernel_size=(k_height, k_width), bias=False)
    self.conv.weight = torch.nn.Parameter(weight)
Run Code Online (Sandbox Code Playgroud)

python pytorch

25
推荐指数
2
解决办法
2万
查看次数

标签 统计

python ×1

pytorch ×1