pytorch gather fails with sparse_grad=True

avo*_*ado 5 python backpropagation pytorch

Even in a very simple example, backward() fails when sparse_grad=True; see the error below.

Is this error expected, or am I using gather the wrong way?

In [1]: import torch as th

In [2]: x = th.rand((3,3), requires_grad=True)

# sparse_grad=False: backward works as expected
In [3]: th.gather(x @ x, 1, th.LongTensor([[0], [1]]), sparse_grad=False).sum().backward()

# sparse_grad=True: backward fails
In [4]: th.gather(x @ x, 1, th.LongTensor([[0], [1]]), sparse_grad=True).sum().backward()
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
----> 1 th.gather(x @ x, 1, th.LongTensor([[0], [1]]), sparse_grad=True).sum().backward()

~/miniconda3/lib/python3.9/site-packages/torch/_tensor.py in backward(self, gradient, retain_graph, create_graph, inputs)
    305                 create_graph=create_graph,
    306                 inputs=inputs)
--> 307         torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
    308
    309     def register_hook(self, hook):

~/miniconda3/lib/python3.9/site-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    152         retain_graph = create_graph
    153
--> 154     Variable._execution_engine.run_backward(
    155         tensors, grad_tensors_, retain_graph, create_graph, inputs,
    156         allow_unreachable=True, accumulate_grad=True)  # allow_unreachable flag

RuntimeError: sparse tensors do not have strides

Sha*_*hai 2

I think torch.gather does not support sparse operands:

torch.gather(x, 1, torch.LongTensor([[0], [1]]).to_sparse())

Result:

NotImplementedError: Could not run 'aten::gather.out' with arguments from the 'SparseCPU' backend.
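That said, sparse_grad=True itself does appear to work when you gather directly from a leaf tensor; my guess (an assumption, not something I verified in the PyTorch source) is that the failure in the question happens because the sparse gradient produced by gather then flows into matmul's backward, which expects a strided (dense) tensor. A quick check of the leaf case:

```python
import torch as th

x = th.rand((3, 3), requires_grad=True)

# Gather directly from the leaf tensor x (no matmul in between),
# so the sparse gradient is accumulated straight into x.grad.
th.gather(x, 1, th.LongTensor([[0], [1]]), sparse_grad=True).sum().backward()

# x.grad should now be a sparse tensor with non-zeros only at
# the two gathered positions: (0, 0) and (1, 1).
print(x.grad.is_sparse)
```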

I think you should open an issue or feature request on pytorch's github.
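In the meantime, one possible workaround (my suggestion, not part of the original answer) is to let autograd use the dense path with sparse_grad=False, which works even through the matmul, and only convert the resulting gradient to sparse afterwards if a sparse representation is what you need:

```python
import torch as th

x = th.rand((3, 3), requires_grad=True)

# The dense-gradient path backpropagates through x @ x without error.
th.gather(x @ x, 1, th.LongTensor([[0], [1]]), sparse_grad=False).sum().backward()

# Convert the accumulated dense gradient to sparse explicitly if needed.
sparse_grad = x.grad.to_sparse()
```

This pays the cost of materializing the dense gradient, so it only helps if the sparse form is needed downstream, not for saving memory during backward.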