vpa*_*pap 2 parameter-passing pytorch dataloader
我正在使用 U-Net 并实现 2015 年(U-Net:用于生物医学的卷积网络\n图像分割)和 2019 年(U-Net \xe2\x80\x93 用于细胞计数、检测的深度学习)的论文中描述的加权技术,和形态测量)。在该技术中,存在方差 \xcf\x83 和权重 w_0。我希望,尤其是 \xcf\x83 成为一个可学习的参数,而不是猜测数据集之间哪个值最好。
\n我目前对此的看法是扩展 torch.utils.data.DataLoader ,其中新的init有一个额外的参数接受用户指定/可学习的参数。鉴于 torch.utils.data.DataLoader 的源代码,我不明白 DataLoader 在何处以及如何调用 DataSet 实例并因此传递这些参数。
\n代码方面,在 DataSet 定义中有该函数
\ndef __getitem__(self, index):\nRun Code Online (Sandbox Code Playgroud)\n我可以改变为
\ndef __getitem__(self, index, sigma):\nRun Code Online (Sandbox Code Playgroud)\n并利用更新后的、新学习的\xcf\x83。
\n我的问题是,在训练期间,我迭代训练数据集
\nfor epoch in range( checkpoint[ 'epoch'], num_epochs):\n....\n for ii, ( X, y, y_weight, fname) in enumerate( dataLoader[ phase]):\nRun Code Online (Sandbox Code Playgroud)\n在 DataLoader 的枚举中,如何将新的 \xcf\x83 传递给 DataLoader,以便 DataLoader 将其传递给 DataSet getitem上面提到的
\n编辑
\n目前,我在 DataSet 类中定义了一个参数sigma
\nclass MedicalImageDataset( Dataset):\n def __init__(self, fname, img_transform = None, mask_transform = None, weight_transform = None, sigma = 8):\n ...\n self.sigma = sigma\n\n def __getitem__(self, index):\n sigma = self.sigma\n ...\nRun Code Online (Sandbox Code Playgroud)\n我通过 DataLoader 更新为
\ndataLoader[ 'train'].dataset.sigma = model.sigma\nRun Code Online (Sandbox Code Playgroud)\n在哪里,
\nmodel.sigma\nRun Code Online (Sandbox Code Playgroud)\n是一个自定义参数,定义为
\nmodel.register_parameter( name = 'sigma', param = torch.nn.Parameter( torch.tensor( 16, dtype = torch.float16), requires_grad = True))\nRun Code Online (Sandbox Code Playgroud)\n创建模型后。
\n我的问题是model.sigma看起来并没有从一个时代更新到另一个时代。具体来说,与初始值相同。为什么是这样?
看了一下,optimizer.state_dict()我找不到任何名为“sigma”的参数,而我可以在model.named_parameters().
最后,这个参数sigma不附加到任何层,它有点“自由”。
\n您需要做的是将 sigma 设置为数据集的属性并在纪元之间更改它。
对于数据集定义
class UNetDataset(object):
def __init__(self, ..., sigma=5):
self.sigma = sigma
Run Code Online (Sandbox Code Playgroud)
现在,在 中__getitem__,您可以使用 sigma 值self.sigma
现在,在您的训练周期内,每个时期之后,您可以通过设置数据集的 sigma 属性来更改 sigma 值
for epoch in range(num_epochs):
dataset.sigma = #whatever value you want
for i,(x,y) in enumarate(DataLoader):
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
3330 次 |
| 最近记录: |