我有一个 PyTorch 计算图,它由一个执行一些计算的子图组成,然后这个计算的结果(我们称之为x
)被分支到另外两个子图。这两个子图中的每一个都会产生一些标量结果(让我们称它们为y1
和y2
)。我想对这两个结果中的每一个都做一个反向传递(即我想累加两个子图的梯度。我不想执行实际的优化步骤)。
现在,由于内存是这里的一个问题,我想按以下顺序执行操作:首先,计算x
。然后,计算y1
,并执行y1.backward()
while(这是关键点)保留通向 的图x
,但将图从 释放x
到y1
。然后,计算y2
,并执行y2.backward()
。
换句话说,为了在不牺牲太多速度的情况下节省内存,我想保留x
而不需要重新计算它,但是我想在我不再需要它们之后删除所有从x
到 的计算y1
。
问题是retain_graph
函数的参数backward()
将保留通向 的整个图y1
,而我只需要保留通向 的图的一部分x
。
这是我理想中想要的示例:
import torch
w = torch.tensor(1.0)
w.requires_grad_(True)
# sub-graph for calculating `x`
x = w+10
# sub-graph for calculating `y1`
x1 = x*x
y1 = x1*x1
y1.backward(retain_graph=x) # this would not work, since retain_graph is a boolean and can either retain the entire graph or free it.
# sub-graph for calculating `y2`
x2 = torch.sqrt(x)
y2 = x2/2
y2.backward()
Run Code Online (Sandbox Code Playgroud)
如何才能做到这一点?
该参数retain_graph
将保留整个图,而不仅仅是一个子图。但是,我们可以使用垃圾收集来释放图形中不需要的部分。通过从移除到子图中所有引用x
到y1
,该子图将被释放:
import torch
w = torch.tensor(1.0)
w.requires_grad_(True)
# sub-graph for calculating `x`
x = w+10
# sub-graph for calculating `y1`
x1 = x*x
y1 = x1*x1
y1.backward(retain_graph=True) # all graph is retained
# remove unneeded parts of graph. Note that these parts will be freed from memory (even if they were on GPU), due to python's garbage collection
y1 = None
x1 = None
# sub-graph for calculating `y2`
x2 = torch.sqrt(x)
y2 = x2/2
y2.backward()
Run Code Online (Sandbox Code Playgroud)
归档时间: |
|
查看次数: |
2520 次 |
最近记录: |