小编Luc*_*toe的帖子

PyTorch Lightning trainer.fit 卡在 epoch 0

我试图使用 PyTorch 和 PyTorch Lightning 制作多输入模型,但我不明白为什么训练器卡在 epoch 0。我试图将此代码从 TensorFlow 迁移到 PyTorch,但 PyTorch 学习曲线是有点陡,我不知道从这里该去哪里。

RC_train_config = config.init_dataset_config(
'RC',
'GI4E',
'label',
16,
lr = 0.001,
epochs = 500,
train_ratio = 0.8
Run Code Online (Sandbox Code Playgroud)

模型的配置,包括超参数和使用的数据集。它也用于数据选择,因为不同的数据集需要不同的处理方法。

class RCDataset(Dataset):
def __init__(self, config_dataset):
    super().__init__()
    self.config_dataset = config_dataset
    
    # Image-handling
    if self.config_dataset['dataset'] == 'all':
        pass
    elif self.config_dataset['dataset'] == 'BIOID':
        if self.config_dataset['mode'] == 'label':
            pass
        elif self.config_dataset['mode'] == 'filter':
            pass
    elif self.config_dataset['dataset'] == 'GI4E':
        if self.config_dataset['mode'] == 'label':
            image1_noteye_paths = glob(C.WORKING_DATASETS['GI4E']['images_label'] + '/0/noteye/*')
            image1_eye_paths = glob(C.WORKING_DATASETS['GI4E']['images_label'] + '/0/left/*')
            image1_eye_paths += …
Run Code Online (Sandbox Code Playgroud)

python machine-learning conv-neural-network pytorch pytorch-lightning

5
推荐指数
0
解决办法
1509
查看次数