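I'm a student and a beginner with Python and PyTorch. I have a very basic neural network, and I'm running into the RuntimeError mentioned in the title. Here is the code that reproduces the error: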
```python
import numpy as np
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
import matplotlib.pyplot as plt

# Ensure reproducibility
torch.manual_seed(0)

# Data generation
x = torch.randn((100, 1), requires_grad=True)
y = 1 + 2 * x + 0.3 * torch.randn(100, 1)

# Shuffle the indices
idx = np.arange(100)
np.random.shuffle(idx)

# Use the first 70 random indices for training
train_idx = idx[:70]
# Use the remaining indices for validation
val_idx = idx[70:]

# Generates train …
```
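Since the snippet is cut off right where the train and validation tensors are built, here is a minimal sketch of how a shuffled-index 70/30 split is typically applied to the tensors. It is an assumption about the missing part, not the original code, and it does not attempt to address the RuntimeError itself. The names `x_train`, `y_train`, `x_val` and `y_val` are hypothetical. One detail worth noting: `torch.manual_seed(0)` seeds only PyTorch's generator, not NumPy's, so `np.random.shuffle` gives a different split on every run unless NumPy is seeded as well (the `np.random.seed(0)` call below is an added assumption).

```python
import numpy as np
import torch

torch.manual_seed(0)  # seeds PyTorch only
np.random.seed(0)     # assumption: also seed NumPy so the shuffle is reproducible

# Same synthetic data as in the question: y = 1 + 2x + noise
x = torch.randn((100, 1), requires_grad=True)
y = 1 + 2 * x + 0.3 * torch.randn(100, 1)

# Shuffle the indices and split them 70 / 30
idx = np.arange(100)
np.random.shuffle(idx)
train_idx = idx[:70]
val_idx = idx[70:]

# Convert the NumPy index arrays to int64 tensors before indexing
# (np.arange may yield int32 on some platforms)
train_t = torch.as_tensor(train_idx, dtype=torch.long)
val_t = torch.as_tensor(val_idx, dtype=torch.long)

# Hypothetical names for the split tensors
x_train, y_train = x[train_t], y[train_t]
x_val, y_val = x[val_t], y[val_t]

print(x_train.shape, x_val.shape)
```

A pure-PyTorch alternative is `idx = torch.randperm(100)`, which is covered by `torch.manual_seed` alone and avoids the NumPy-to-tensor conversion.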