dkv*_*dkv 21 copy pytorch tensor
在Pytorch中似乎有几种创建张量副本的方法,包括
y = tensor.new_tensor(x) #a
y = x.clone().detach() #b
y = torch.empty_like(x).copy_(x) #c
y = torch.tensor(x) #d
Run Code Online (Sandbox Code Playgroud)
b明确优于a并d根据UserWarning如果我既可以执行我得到a或d。为什么首选它?性能?我认为它的可读性较差。
有/反对使用任何理由c?
kHa*_*hit 79
TL; 博士
使用.clone().detach()(或最好.detach().clone())
如果先分离张量然后克隆它,则不会复制计算路径,反之则复制然后放弃。因此,
.detach().clone()效率稍微高一点。-- pytorch论坛
因为它的功能稍微快速和明确。
使用perflot,我绘制了复制 pytorch 张量的各种方法的时间。
y = tensor.new_tensor(x) # method a
y = x.clone().detach() # method b
y = torch.empty_like(x).copy_(x) # method c
y = torch.tensor(x) # method d
y = x.detach().clone() # method e
Run Code Online (Sandbox Code Playgroud)
x 轴是创建的张量的维度,y 轴显示时间。该图采用线性标度。如您所见,与其他三种方法相比,tensor()ornew_tensor()需要更多时间。
注意:在多次运行中,我注意到 b、c、e 中的任何方法都可能具有最短的时间。对于 a 和 d 也是如此。但是,方法 b、c、e 始终比 a 和 d 具有更低的时间。
import torch
import perfplot
perfplot.show(
setup=lambda n: torch.randn(n),
kernels=[
lambda a: a.new_tensor(a),
lambda a: a.clone().detach(),
lambda a: torch.empty_like(a).copy_(a),
lambda a: torch.tensor(a),
lambda a: a.detach().clone(),
],
labels=["new_tensor()", "clone().detach()", "empty_like().copy()", "tensor()", "detach().clone()"],
n_range=[2 ** k for k in range(15)],
xlabel="len(a)",
logx=False,
logy=False,
title='Timing comparison for copying a pytorch tensor',
)
Run Code Online (Sandbox Code Playgroud)
Nop*_*eos 16
根据Pytorch 文档, #a 和 #b 是等效的。它还说
推荐使用 clone() 和 detach() 的等效项。
所以如果你想复制一个张量并从你应该使用的计算图中分离
y = x.clone().detach()
Run Code Online (Sandbox Code Playgroud)
因为它是最干净和最易读的方式。对于所有其他版本,存在一些隐藏的逻辑,并且计算图和梯度传播会发生什么也不是 100% 清楚。
关于#c:实际完成的工作似乎有点复杂,并且还可能引入一些开销,但我不确定。
编辑:既然在评论中被问到为什么不直接使用.clone().
与 copy_() 不同,此函数记录在计算图中。传播到克隆张量的梯度将传播到原始张量。
因此,在.clone()返回数据副本时,它会保留计算图并在其中记录克隆操作。如前所述,这将导致传播到克隆张量的梯度也传播到原始张量。这种行为会导致错误并且并不明显。由于这些可能的副作用,只有在.clone()明确需要这种行为时才应通过克隆张量。为了避免这些副作用,.detach()添加了 来断开计算图与克隆张量的连接。
由于通常对于复制操作,人们需要一个干净的副本,不会导致不可预见的副作用,因此复制张量的首选方法是.clone().detach().
检查张量是否被复制的一个示例:
import torch
def samestorage(x,y):
if x.storage().data_ptr()==y.storage().data_ptr():
print("same storage")
else:
print("different storage")
a = torch.ones((1,2), requires_grad=True)
print(a)
b = a
c = a.data
d = a.detach()
e = a.data.clone()
f = a.clone()
g = a.detach().clone()
i = torch.empty_like(a).copy_(a)
j = torch.tensor(a) # UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
print("a:",end='');samestorage(a,a)
print("b:",end='');samestorage(a,b)
print("c:",end='');samestorage(a,c)
print("d:",end='');samestorage(a,d)
print("e:",end='');samestorage(a,e)
print("f:",end='');samestorage(a,f)
print("g:",end='');samestorage(a,g)
print("i:",end='');samestorage(a,i)
Run Code Online (Sandbox Code Playgroud)
出去:
tensor([[1., 1.]], requires_grad=True)
a:same storage
b:same storage
c:same storage
d:same storage
e:different storage
f:different storage
g:different storage
i:different storage
j:different storage
Run Code Online (Sandbox Code Playgroud)
如果出现不同的存储,则复制张量。PyTorch 有近 100 个不同的构造函数,因此您可以添加更多方法。
如果我需要复制张量,我只会使用copy(),这也会复制 AD 相关信息,所以如果我需要删除 AD 相关信息,我会使用:
y = x.clone().detach()
Run Code Online (Sandbox Code Playgroud)