如何从pytorch dataloader获取批量迭代的总数?

Hyu*_*Kim 3 for-loop pytorch dataloader

我有一个问题,如何从 pytorch 数据加载器获取批量迭代的总数?

以下是训练的常用代码

for i, batch in enumerate(dataloader):
Run Code Online (Sandbox Code Playgroud)

那么,有没有什么方法可以获取“for循环”的总迭代次数?

在我的 NLP 问题中,总迭代次数不同于 int(n_train_samples/batch_size)...

例如,如果我只截断训练数据 10,000 个样本并将批大小设置为 1024,那么在我的 NLP 问题中会发生 363 次迭代。

我想知道如何获得“for 循环”中的总迭代次数。

谢谢你。

hkc*_*rex 9

len(dataloader)返回批次总数。这取决于__len__数据集的功能,因此请确保正确设置。