如何帮助 tqdm 计算自定义迭代器中的总数

Dua*_*ane 5 python tqdm

我正在实现我自己的迭代器。tqdm 不显示进度条,因为它不知道列表中元素的总数。我不想使用“total=”,因为它看起来很丑。相反,我更愿意在我的迭代器中添加一些 tqdm 可以用来计算总数的东西。

class Batches:
    def __init__(self, batches, target_input):
        self.batches = batches
        self.pos = 0
        self.target_input = target_input

    def __iter__(self):
        return self

    def __next__(self):
        if self.pos < len(self.batches):
            minibatch = self.batches[self.pos]
            target = minibatch[:, :, self.target_input]
            self.pos += 1
            return minibatch, target
        else:
            raise StopIteration

    def __len__(self):
        return self.batches.len()
Run Code Online (Sandbox Code Playgroud)

这甚至可能吗?在上面的代码中添加什么...

使用 tqdm 如下所示..

for minibatch, target in tqdm(Batches(test, target_input)):

    output = lstm(minibatch)
    loss = criterion(output, target)
    writer.add_scalar('loss', loss, tensorboard_step)
Run Code Online (Sandbox Code Playgroud)

eme*_*mem 17

我知道已经有一段时间了,但我一直在寻找相同的答案,这是解决方案。而不是像这样用 tqdm 包装你的迭代

for i in tqdm(my_iterable):
    do_something()
Run Code Online (Sandbox Code Playgroud)

改用“with”关闭,如:

with tqdm(total=len(my_iterable)) as progress_bar:
    for i in my_iterable:
        do_something()
        progress_bar.update(1) # update progress
Run Code Online (Sandbox Code Playgroud)

对于您的批次,您可以将总数设置为批次数,并更新为 1(如上)。或者,您可以将总数设置为实际的项目总数,并将更新设置为当前处理批次的大小。

  • 只是注意,如果没有特殊的更新需求,可以简单地将总计添加到 tqdm 构造函数中,即使用:`for i in tqdm(my_iterable,total=my_total)` (2认同)

Nic*_*ens 11

原来的问题是:

我不想使用“total=”,因为它看起来很难看。相反,我更愿意在迭代器中添加一些内容,tqdm 可以使用它来计算总数。

但是,当前接受的答案明确指出要使用total

with tqdm(total=len(my_iterable)) as progress_bar:
Run Code Online (Sandbox Code Playgroud)

事实上,给定的示例比实际需要的更复杂,因为原始问题并未要求对栏进行复杂的更新。因此,

for i in tqdm(my_iterable, total=my_total):
    do_something()
Run Code Online (Sandbox Code Playgroud)

实际上已经足够了(正如作者@emem,已经在评论中指出的那样)。


这个问题相对较老(撰写本文时为 4 年),但是查看 tqdm 的代码,可以看到从一开始(撰写本文时为 8 年前),该行为就默认为以防total = len(iterable)万一total没有给出。

因此,问题的正确答案是实施__len__。正如问题中所述,原始示例已经实现了。因此,它应该已经可以正常工作了。

可以在下面找到测试行为的完整玩具示例(请注意方法上方的注释__len__):

from time import sleep
from tqdm import tqdm


class Iter:

    def __init__(self, n=10):
        self.n = n
        self.iter = iter(range(n))

    def __iter__(self):
        return self

    def __next__(self):
        return next(self.iter)

    # commenting the next two lines disables showing the bar
    # due to tqdm not knowing the total number of elements:
    def __len__(self):
        return self.n


it = Iter()
for i in tqdm(it):
    sleep(0.2)
Run Code Online (Sandbox Code Playgroud)

看看 tqdm 到底做了什么:

try:
    total = len(iterable)
except (TypeError, AttributeError):
    total = None
Run Code Online (Sandbox Code Playgroud)

...并且由于我们不确切知道 @Duane 用作什么batches,我认为这基本上只是一个隐藏得很好的拼写错误 ( self.batches.len()),这会导致AttributeErrortqdm 中捕获到 。

如果batches只是一个序列类型,那么这可能是预期的定义:

    def __len__(self):
        return len(self.batches)
Run Code Online (Sandbox Code Playgroud)

__next__(using )的定义len(self.batches)也指向这个方向。