the*_*nac 1 pytorch pytorch-dataloader huggingface-datasets
我正在开发一个 ASR 项目,其中使用 HuggingFace ( ) 的模型wav2vec2。我现在的目标是将训练过程转移到 PyTorch,因此我尝试重新创建 HuggingFace\xe2\x80\x99sTrainer()类提供的所有内容。
这些实用程序之一是能够按长度对批次进行分组并将其与动态填充相结合(通过数据整理器)。但说实话,我什至不知道如何在 PyTorch 中开始这一点。
\n在我的例子中,输入是一维数组,表示 .wav 文件的原始波形。因此,在训练之前,我需要确保将相似大小的数组分批在一起。我是否需要创建一个自定义 Dataloader 类并对其进行更改,以便每次它都能为我提供长度尽可能接近的批量大小?
\n我的一个想法是以某种方式将数据从最短到最长(或相反)排序,并每次从中提取batch_size样本。这样,第一批将包含最大长度的样本,第二批将包含第二大长度的样本,依此类推。
\n尽管如此,我不确定如何实现这个实现。任何建议将不胜感激。
\n提前致谢。
\n解决此问题的一种可能方法是使用批处理采样器并为数据加载器实现collate_fn,该数据加载器将对批处理元素执行动态填充。
采用这个基本数据集:
class DS(Dataset):
def __init__(self, files):
super().__init__()
self.len = len(files)
self.files = files
def __getitem__(self, index):
return self.files[index]
def __len__(self):
return self.len
Run Code Online (Sandbox Code Playgroud)
用一些随机数据初始化:
>>> file_len = np.random.randint(0, 100, (16*6))
>>> files = [np.random.rand(s) for s in file_len]
>>> ds = DS(files)
Run Code Online (Sandbox Code Playgroud)
首先定义批量采样器,这本质上是一个可迭代的返回批量索引,供数据加载器用来从数据集中检索元素。正如您所解释的,我们可以对长度进行排序并从此排序构建不同的批次:
>>> batch_size = 16
>>> batches = np.split(file_len.argsort()[::-1], batch_size)
Run Code Online (Sandbox Code Playgroud)
我们应该拥有长度彼此接近的元素。
我们可以实现一个collate_fn函数来组装批处理元素并集成动态填充。这基本上是在数据集和数据加载器之间放置一个额外的用户定义层。目标是找到批次中最长的元素,并用正确的0s 数量填充所有其他元素:
def collate_fn(batch):
longest = max([len(x) for x in batch])
s = np.stack([np.pad(x, (0, longest - len(x))) for x in batch])
return torch.from_numpy(s)
Run Code Online (Sandbox Code Playgroud)
然后你可以初始化一个数据加载器:
>>> dl = DataLoader(dataset=ds, batch_sampler=batches, collate_fn=collate_fn)
Run Code Online (Sandbox Code Playgroud)
并尝试迭代,如您所见,我们得到了长度递减的批次:
>>> for x in dl:
... print(x.shape)
torch.Size([6, 99])
torch.Size([6, 93])
torch.Size([6, 83])
torch.Size([6, 76])
torch.Size([6, 71])
torch.Size([6, 66])
torch.Size([6, 57])
...
Run Code Online (Sandbox Code Playgroud)
但这种方法有一些缺陷,例如元素的分布总是相同的。这意味着您将始终以相同的出现顺序获得相同的批次。这是因为该方法基于数据集中元素的长度进行排序,因此批次的创建没有变化。您可以通过打乱批次来减少这种影响(例如,通过将其包裹batches在 a 内RandomSampler)。然而,正如我所说,批次的内容在整个培训过程中将保持不变,这可能会导致一些问题。
请注意,在数据加载器中使用的batch_sampler是互斥选项batch_size, shuffle, 和sampler!
| 归档时间: |
|
| 查看次数: |
3196 次 |
| 最近记录: |