如果我只对一些样本进行转发,什么时候计算图会被释放?

jdh*_*hao 5 python backpropagation pytorch

我有一个用例,我对批次中的每个样本进行转发,并且仅根据样本模型输出的某些条件累积某些样本的损失。这是一个说明性代码,

for batch_idx, (data, target) in enumerate(train_loader):
    optimizer.zero_grad()
    total_loss = 0

    loss_count_local = 0
    for i in range(len(target)):
        im = Variable(data[i].unsqueeze(0).cuda())
        y = Variable(torch.FloatTensor([target[i]]).cuda())

        out = model(im)

        # if out satisfy some condtion, we will calculate loss
        # for this sample, else proceed to next sample
        if some_condition(out):
            loss = criterion(out, y)
        else:
            continue

        total_loss += loss
        loss_count_local += 1

        if loss_count_local == 32 or i == (len(target)-1):
            total_loss /= loss_count_local
            total_loss.backward()
            total_loss = 0
            loss_count_local = 0

    optimizer.step()
Run Code Online (Sandbox Code Playgroud)

我的问题是,正如我对所有样本进行前向操作,但仅对某些样本进行后向操作。那些不造成损失的样本的图表什么时候会被释放?这些图表是否仅在 for 循环结束后或在我转发下一个示例后立即释放?我在这里有点困惑。

另外,对于那些对 做出贡献的样本total_loss,它们的图表将在我们这样做后立即释放total_loss.backward()。是对的吗?

cle*_*ros 4

让我们从 PyTorch 如何释放内存的一般讨论开始:

首先,我们应该强调 PyTorch 使用存储在 Python 对象属性中的隐式声明图。(记住,它是Python,所以一切都是对象)。更具体地说,torch.autograd.Variables 有一个.grad_fn属性。该属性的类型定义了我们拥有什么类型的计算节点(例如加法)以及该节点的输入。

这很重要,因为 Pytorch 只需使用标准 python 垃圾收集器(如果相当积极的话)即可释放内存。在这种情况下,这意味着(隐式声明的)计算图将保持活动状态,只要存在对当前作用域中保存它们的对象的引用!

这意味着,如果您对样本 s_1 ... s_k 进行某种批处理,计算每个损失并在最后添加损失,则累积损失将保存对每个单独损失的引用,而每个损失又保存对每个损失的引用计算它的计算节点的数量。

因此,应用于代码的问题更多地是关于 Python(或者更具体地说,它的垃圾收集器)如何处理引用,而不是关于 Pytorch 的处理引用。由于您将损失累积在一个对象 ( total_loss) 中,因此您可以使指针保持活动状态,从而在外循环中重新初始化该对象之前不会释放内存。

应用于您的示例,这意味着您在前向传递(在out = model(im))中创建的计算图仅由对象及其任何未来计算引用out。因此,如果您计算损失并对其求和,您将保留对out活动的引用,从而保留对计算图的引用。但是,如果您不使用它,垃圾收集器应该递归收集out及其计算图。