小编Gau*_*sra的帖子

torch.no_grad() 的目的是什么:

考虑以下使用 PyTorch 实现的线性回归代码:

X是输入,Y是训练集的输出,w是需要优化的参数
import torch

X = torch.tensor([1, 2, 3, 4], dtype=torch.float32)
Y = torch.tensor([2, 4, 6, 8], dtype=torch.float32)

w = torch.tensor(0.0, dtype=torch.float32, requires_grad=True)

def forward(x):
    return w * x

def loss(y, y_pred):
    return ((y_pred - y)**2).mean()

print(f'Prediction before training: f(5) = {forward(5).item():.3f}')

learning_rate = 0.01
n_iters = 100

for epoch in range(n_iters):
    # predict = forward pass
    y_pred = forward(X)

    # loss
    l = loss(Y, y_pred)

    # calculate gradients = backward pass
    l.backward()

    # update weights
    #w.data = …
Run Code Online (Sandbox Code Playgroud)

python gradient machine-learning linear-regression pytorch

9
推荐指数
2
解决办法
2万
查看次数