torch.utils.data.random_split() 没有分割数据

use*_*519 2 deep-learning pytorch

I\xe2\x80\x99m 当我使用时没有被分割torch.utils.data.random_split

\n\n

train_size我得到了和的正确数字val_size,但是当我这样做时random_splittrain_data和都val_data得到了full_data。没有发生分裂。

\n\n

请帮我解决这个问题。

\n\n
class DeviceLoader(Dataset):\n\ndef __init__(self, root_dir, train=True, transform=None):\n    self.file_path = root_dir\n    self.train = train\n    self.transform = transform\n    self.file_names = ['%s/%s'%(root,file) for root,_,files in os.walk(root_dir) for file in files]\n    self.len = len(self.file_names)\n    self.labels = {'BP_Raw_Images':0, 'DT_Raw_Images':1, 'GL_Raw_Images':2, 'PO_Raw_Images':3, 'WS_Raw_Images':4}\n\ndef __len__(self):\n    return(len(self.file_names))\n\ndef __getitem__(self, idx):\n    file_name = self.file_names[idx]\n    device = file_name.split('/')[5]\n    img = self.pil_loader(file_name)\n    if(self.transform):\n        img = self.transform(img)\n    cat = self.labels[device]            \n    if(self.train):\n        return(img, cat)\n    else:\n        return(img, file_name)\nfull_data = DeviceLoader(root_dir=\xe2\x80\x99/kaggle/input/devices/dataset/\xe2\x80\x99, transform=transforms, train=True)\ntrain_size = int(0.7*len(full_data))\nval_size = len(full_data) - train_size\ntrain_data, val_data = torch.utils.data.random_split(full_data,[train_size,val_size])\n
Run Code Online (Sandbox Code Playgroud)\n\n

预期结果是将 分成full_data( train_data2000) 和val_data(500)。但相反,我得到full_data在 train 和 val 中都得到了 (2500)。

\n

pau*_*uvo 6

从下图中您可以看到,它实际上生成了数据的子集,但原始数据集仍然存在。这可能会令人困惑。我在 mnist 数据集上做了以下操作

train, validate, test = data.random_split(training_set, [50000, 10000, 10000])
print(len(train))
print(len(validate))
print(len(test))
Run Code Online (Sandbox Code Playgroud)

输出:

50000
10000
10000
Run Code Online (Sandbox Code Playgroud)

在此输入图像描述