我有一个 PyTorch 计算图,它由一个执行一些计算的子图组成,然后这个计算的结果(我们称之为x)被分支到另外两个子图。这两个子图中的每一个都会产生一些标量结果(让我们称它们为y1和y2)。我想对这两个结果中的每一个都做一个反向传递(即我想累加两个子图的梯度。我不想执行实际的优化步骤)。
现在,由于内存是这里的一个问题,我想按以下顺序执行操作:首先,计算x。然后,计算y1,并执行y1.backward()while(这是关键点)保留通向 的图x,但将图从 释放x到y1。然后,计算y2,并执行y2.backward()。
换句话说,为了在不牺牲太多速度的情况下节省内存,我想保留x而不需要重新计算它,但是我想在我不再需要它们之后删除所有从x到 的计算y1。
问题是retain_graph函数的参数backward()将保留通向 的整个图y1,而我只需要保留通向 的图的一部分x。
这是我理想中想要的示例:
import torch
w = torch.tensor(1.0)
w.requires_grad_(True)
# sub-graph for calculating `x`
x = w+10
# sub-graph for calculating `y1`
x1 = x*x
y1 = x1*x1
y1.backward(retain_graph=x) # this would not work, since retain_graph is a boolean and can either retain the entire graph or free it.
# sub-graph for calculating `y2`
x2 = torch.sqrt(x)
y2 = x2/2
y2.backward()
Run Code Online (Sandbox Code Playgroud)
如何才能做到这一点?
该参数retain_graph将保留整个图,而不仅仅是一个子图。但是,我们可以使用垃圾收集来释放图形中不需要的部分。通过从移除到子图中所有引用x到y1,该子图将被释放:
import torch
w = torch.tensor(1.0)
w.requires_grad_(True)
# sub-graph for calculating `x`
x = w+10
# sub-graph for calculating `y1`
x1 = x*x
y1 = x1*x1
y1.backward(retain_graph=True) # all graph is retained
# remove unneeded parts of graph. Note that these parts will be freed from memory (even if they were on GPU), due to python's garbage collection
y1 = None
x1 = None
# sub-graph for calculating `y2`
x2 = torch.sqrt(x)
y2 = x2/2
y2.backward()
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
2520 次 |
| 最近记录: |