我在 pytorch 中编写了一个自定义数据加载器类。但是在迭代一个纪元内的所有批次时,它失败了。例如,假设我有 100 个数据示例,我的批处理大小为 9。它会在第 10 次迭代中失败,说批处理大小不同,这将使批处理大小为 1 而不是 10。我已将自定义数据加载器放在下面。此外,我已将如何从 for 循环内的加载程序中提取数据。
class FlatDirectoryAudioDataset(tdata.Dataset): #customized dataloader
def __init__(self, data_dir, transform=None):
self.data_dir = data_dir
self.transform = transform
self.files = self.__setup_files()
def __len__(self):
"""
compute the length of the dataset
:return: len => length of dataset
"""
return len(self.files)
def __setup_files(self):
file_names = os.listdir(self.data_dir)
files = [] # initialize to empty list
for file_name in file_names:
possible_file = os.path.join(self.data_dir, file_name)
if os.path.isfile(possible_file) and (file_name.lower().endswith('.wav') or file_name.lower().endswith('.mp3')): #&& (possible_file.lower().endswith('.wav') or possible_file.lower().endswith('.mp3')):
files.append(possible_file)
# return the files list
return files
def __getitem__ (self,index):
sample, _ = librosa.load(self.files[index], 16000)
if self.transform:
sample=self.transform(sample)
sample = torch.from_numpy(sample)
return sample
from torch.utils.data import DataLoader
my_dataset=FlatDirectoryAudioDataset(source_directory,source_folder,source_label,transform = None,label=True)
dataloader_my = DataLoader(
my_dataset,
batch_size=batch_size,
num_workers=0,
shuffle=True)
for (i,batch) in enumerate(dataloader_my,0):
print(i)
if batch.shape[0]!=16:
print(batch.shape)
assert batch.shape[0]==16,"Something wrong with the batch size"
Run Code Online (Sandbox Code Playgroud)
小智 14
使用 drop_last=True utils.DataLoader(dataset,batch_size=batch_size,shuffle = True,drop_last=True)
https://pytorch.org/docs/stable/data.html
设置drop_last=True
为删除最后一个不完整的批次
根据您制作 Dataloader 简化版本的代码,批量大小没有错误。
使用 9 asbatch_size
并有 100 个项目,最后一批只有一个项目。运行它下面的代码会产生。
设置 drop_last=False 打印最后一行并打印“异常”。
0 <class 'torch.Tensor'> torch.Size([9, 1])
1 <class 'torch.Tensor'> torch.Size([9, 1])
2 <class 'torch.Tensor'> torch.Size([9, 1])
3 <class 'torch.Tensor'> torch.Size([9, 1])
4 <class 'torch.Tensor'> torch.Size([9, 1])
5 <class 'torch.Tensor'> torch.Size([9, 1])
6 <class 'torch.Tensor'> torch.Size([9, 1])
7 <class 'torch.Tensor'> torch.Size([9, 1])
8 <class 'torch.Tensor'> torch.Size([9, 1])
9 <class 'torch.Tensor'> torch.Size([9, 1])
10 <class 'torch.Tensor'> torch.Size([9, 1])
# depends on drop_last=True|False
11 <class 'torch.Tensor'> torch.Size([1, 1])
Different batch size (last batch) torch.Size([1, 1])
Run Code Online (Sandbox Code Playgroud)
因此,该批次生产的批次项目足够好,总共可以达到 100 个
0 <class 'torch.Tensor'> torch.Size([9, 1])
1 <class 'torch.Tensor'> torch.Size([9, 1])
2 <class 'torch.Tensor'> torch.Size([9, 1])
3 <class 'torch.Tensor'> torch.Size([9, 1])
4 <class 'torch.Tensor'> torch.Size([9, 1])
5 <class 'torch.Tensor'> torch.Size([9, 1])
6 <class 'torch.Tensor'> torch.Size([9, 1])
7 <class 'torch.Tensor'> torch.Size([9, 1])
8 <class 'torch.Tensor'> torch.Size([9, 1])
9 <class 'torch.Tensor'> torch.Size([9, 1])
10 <class 'torch.Tensor'> torch.Size([9, 1])
# depends on drop_last=True|False
11 <class 'torch.Tensor'> torch.Size([1, 1])
Different batch size (last batch) torch.Size([1, 1])
Run Code Online (Sandbox Code Playgroud)
归档时间: |
|
查看次数: |
9725 次 |
最近记录: |