Cannot send a PyTorch tensor to CUDA

Gui*_*ier 1 python gpu pytorch

I created a torch tensor and I want it to go to the GPU, but it doesn't. This seems broken. What's wrong?

def test_model_works_on_gpu():
    with torch.cuda.device(0) as cuda:
        some_random_d_model = 2 ** 9
        five_sentences_of_twenty_words = torch.from_numpy(np.random.random((5, 20, T * d))).float()
        five_sentences_of_twenty_words_mask = torch.from_numpy(np.ones((5, 1, 20))).float()
        pytorch_model = make_sentence_model(d_model=some_random_d_model, T_sgnn=T, d_sgnn=d)

        five_sentences_of_twenty_words.to(cuda)
        five_sentences_of_twenty_words_mask.to(cuda)
        print(type(five_sentences_of_twenty_words), type(five_sentences_of_twenty_words_mask))
        print(five_sentences_of_twenty_words.is_cuda, five_sentences_of_twenty_words_mask.is_cuda)
        pytorch_model.to(cuda)
        output_before_match = pytorch_model(five_sentences_of_twenty_words, five_sentences_of_twenty_words_mask)

        assert output_before_match.shape == (5, some_random_d_model)
        print(type(output_before_match))
        print(output_before_match.is_cuda, output_before_match.get_device())
tests/test_model.py:58: RuntimeError

<class 'torch.Tensor'> <class 'torch.Tensor'>
False False
<class 'torch.Tensor'>

>       print(output_before_match.is_cuda, output_before_match.get_device())
E       RuntimeError: get_device is not implemented for tensors with CPU backend

Also:

>>> torch.cuda.is_available()
True
>>> torch.cuda.device_count()
2

And:

pip freeze | grep -i torch
torch==1.0.0
torchvision==0.2.1

uke*_*emi 6

Your problem is the following lines:

five_sentences_of_twenty_words.to(cuda)
five_sentences_of_twenty_words_mask.to(cuda)

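.to(device) only operates in place when applied to a model (an nn.Module).

When applied to a tensor, it returns a tensor and leaves the original variable unchanged, so the result must be assigned: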

five_sentences_of_twenty_words = five_sentences_of_twenty_words.to(cuda)
five_sentences_of_twenty_words_mask = five_sentences_of_twenty_words_mask.to(cuda)
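
To make the difference between the two calls concrete, here is a minimal sketch. The names `x`, `model`, and `device` are hypothetical stand-ins rather than the objects from the question, and the explicit `torch.device(...)` with a CPU fallback is an assumption so the snippet also runs on a machine without a GPU. The dtype conversion is used only because it shows the "returns a result, does not modify the original" behaviour even on CPU.

```python
import torch
import torch.nn as nn

# Assumption: use the first GPU when available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

x = torch.rand(5, 20)      # hypothetical stand-in for the input batch
model = nn.Linear(20, 4)   # hypothetical stand-in for the sentence model

# Tensor.to returns a tensor; discarding the result leaves x untouched.
x.to(torch.float64)
dtype_after_discarded_call = x.dtype  # unchanged, still the default float32

# Rebinding the name keeps the converted/moved tensor.
x = x.to(device)

# Module.to moves the module's parameters in place and returns the module itself,
# so no reassignment is needed (although writing model = model.to(device) is harmless).
returned = model.to(device)
same_object = returned is model

# Input and model now live on the same device, so the forward pass works.
output = model(x)
```

Reassigning on every tensor `.to(...)` call, and printing `tensor.device` afterwards, is a cheap way to confirm where the data actually lives.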