将用户指定的参数传递给DataLoader

vpa*_*pap 2 parameter-passing pytorch dataloader

我正在使用 U-Net 并实现 2015 年(U-Net:用于生物医学的卷积网络\n图像分割)和 2019 年(U-Net \xe2\x80\x93 用于细胞计数、检测的深度学习)的论文中描述的加权技术,和形态测量)。在该技术中,存在方差 \xcf\x83 和权重 w_0。我希望,尤其是 \xcf\x83 成为一个可学习的参数,而不是猜测数据集之间哪个值最好。

\n
    \n
  1. 根据我的发现,我可以使用 nn.Parameter 来做到这一点。
  2. \n
  3. 为了使用学习到的 \xcf\x83 从一个纪元到另一个纪元,我需要以某种方式通过 DataLoader 将这个新值传递给 DataSet 的 get_item 函数。
  4. \n
\n

我目前对此的看法是扩展 torch.utils.data.DataLoader ,其中新的init有一个额外的参数接受用户指定/可学习的参数。鉴于 torch.utils.data.DataLoader 的源代码,我不明白 DataLoader 在何处以及如何调用 DataSet 实例并因此传递这些参数。

\n

代码方面,在 DataSet 定义中有该函数

\n
def __getitem__(self, index):\n
Run Code Online (Sandbox Code Playgroud)\n

我可以改变为

\n
def __getitem__(self, index, sigma):\n
Run Code Online (Sandbox Code Playgroud)\n

并利用更新后的、新学习的\xcf\x83。

\n

我的问题是,在训练期间,我迭代训练数据集

\n
for epoch in range( checkpoint[ 'epoch'], num_epochs):\n....\n    for ii, ( X, y, y_weight, fname) in enumerate( dataLoader[ phase]):\n
Run Code Online (Sandbox Code Playgroud)\n

在 DataLoader 的枚举中,如何将新的 \xcf\x83 传递给 DataLoader,以便 DataLoader 将其传递给 DataSet getitem上面提到的

\n

编辑

\n

目前,我在 DataSet 类中定义了一个参数sigma

\n
class MedicalImageDataset( Dataset):\n      def __init__(self, fname, img_transform = None, mask_transform = None, weight_transform = None, sigma = 8):\n            ...\n            self.sigma = sigma\n\n      def __getitem__(self, index):\n            sigma = self.sigma\n            ...\n
Run Code Online (Sandbox Code Playgroud)\n

我通过 DataLoader 更新为

\n
dataLoader[ 'train'].dataset.sigma = model.sigma\n
Run Code Online (Sandbox Code Playgroud)\n

在哪里,

\n
model.sigma\n
Run Code Online (Sandbox Code Playgroud)\n

是一个自定义参数,定义为

\n
model.register_parameter( name = 'sigma', param = torch.nn.Parameter( torch.tensor( 16, dtype = torch.float16), requires_grad = True))\n
Run Code Online (Sandbox Code Playgroud)\n

创建模型后。

\n

我的问题是model.sigma看起来并没有从一个时代更新到另一个时代。具体来说,与初始值相同。为什么是这样?

\n

看了一下,optimizer.state_dict()我找不到任何名为“sigma”的参数,而我可以在model.named_parameters().

\n

最后,这个参数sigma不附加到任何层,它有点“自由”。

\n

Ros*_*osh 5

您需要做的是将 sigma 设置为数据集的属性并在纪元之间更改它。

对于数据集定义

class UNetDataset(object):
    def __init__(self, ..., sigma=5):

        self.sigma = sigma
Run Code Online (Sandbox Code Playgroud)

现在,在 中__getitem__,您可以使用 sigma 值self.sigma

现在,在您的训练周期内,每个时期之后,您可以通过设置数据集的 sigma 属性来更改 sigma 值

for epoch in range(num_epochs):
    dataset.sigma = #whatever value you want

    for i,(x,y) in enumarate(DataLoader):

Run Code Online (Sandbox Code Playgroud)