相关疑难解决方法(0)

为什么在 Pytorch 张量上调用 .numpy() 之前调用 .detach() ?

已经确定这my_tensor.detach().numpy()是从torch张量获取 numpy 数组的正确方法。

我试图更好地理解为什么。

在刚刚链接的问题的公认答案中,Blupon 指出:

您需要将您的张量转换为另一个除了其实际值定义之外不需要梯度的张量。

在他链接到的第一个讨论中,albanD 指出:

这是预期的行为,因为移动到 numpy 会破坏图形,因此不会计算梯度。

如果您实际上不需要渐变,那么您可以显式地 .detach() 需要 grad 的 Tensor 以获得具有相同内容但不需要 grad 的张量。然后可以将这个其他 Tensor 转换为一个 numpy 数组。

在他链接到的第二个讨论中,apaszke 写道:

变量不能转换为 numpy,因为它们是保存操作历史的张量的包装器,而 numpy 没有这样的对象。您可以使用 .data 属性检索变量持有的张量。然后,这应该有效:var.data.numpy()。

我研究了 PyTorch 的自动分化库的内部工作原理,但我仍然对这些答案感到困惑。为什么它会破坏图形以移动到 numpy?是否因为在 autodiff 图中不会跟踪对 numpy 数组的任何操作?

什么是变量?它与张量有什么关系?

我觉得这里需要一个彻底的高质量 Stack-Overflow 答案,向尚不了解自动分化的 PyTorch 新用户解释原因。

特别是,我认为通过一个图形来说明图形并显示在此示例中断开连接是如何发生的会很有帮助:

import torch

tensor1 = torch.tensor([1.0,2.0],requires_grad=True)

print(tensor1)
print(type(tensor1))

tensor1 = tensor1.numpy()

print(tensor1)
print(type(tensor1))
Run Code Online (Sandbox Code Playgroud)

numpy autodiff pytorch

30
推荐指数
3
解决办法
2万
查看次数

RuntimeError: 预计所有张量都在同一设备上,但发​​现至少有两个设备,cuda:0 和 cpu!恢复训练时

我在 gpu 上训练时保存了一个检查点。重新加载检查点并继续训练后,我收到以下错误。

Traceback (most recent call last):
  File "main.py", line 140, in <module>
    train(model,optimizer,train_loader,val_loader,criteria=args.criterion,epoch=epoch,batch=batch)
  File "main.py", line 71, in train
    optimizer.step()
  File "/opt/conda/lib/python3.7/site-packages/torch/autograd/grad_mode.py", line 26, in decorate_context
    return func(*args, **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/torch/optim/sgd.py", line 106, in step
    buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
Run Code Online (Sandbox Code Playgroud)

我的训练代码是:

def train(model,optimizer,train_loader,val_loader,criteria,epoch=0,batch=0):
    batch_count = batch
    if criteria == 'l1':
        criterion = L1_imp_Loss()
    elif criteria == 'l2':
        criterion = L2_imp_Loss() …
Run Code Online (Sandbox Code Playgroud)

python runtime-error deep-learning pytorch

5
推荐指数
4
解决办法
6102
查看次数

标签 统计

pytorch ×2

autodiff ×1

deep-learning ×1

numpy ×1

python ×1

runtime-error ×1