San*_*mar 21 python precision casting deep-learning pytorch
I am trying to implement a neural network in PyTorch, but it doesn't seem to work. The problem appears to be in the training loop. I have spent several hours on this but can't get it right. Please help, thanks.
I haven't added the data preprocessing part yet.
# importing libraries
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torch.nn.functional as F
# get x function (dataset related stuff)
def Getx(idx):
    sample = samples[idx]
    vector = Calculating_bottom(sample)
    vector = torch.as_tensor(vector, dtype = torch.float64)
    return vector

# get y function (dataset related stuff)
def Gety(idx):
    y = np.array(train.iloc[idx, 4], dtype = np.float64)
    y = torch.as_tensor(y, dtype = torch.float64)
    return y
# dataset
class mydataset(Dataset):
    def __init__(self):
        super().__init__()

    def __getitem__(self, index):
        x = Getx(index)
        y = Gety(index)
        return x, y

    def __len__(self):
        return len(train)

dataset = mydataset()
# sample dataset value
print(dataset.__getitem__(0))
(tensor([ 5.,  5.,  8., 14.], dtype=torch.float64), tensor(-0.3403, dtype=torch.float64))
# data-loader
dataloader = DataLoader(dataset, batch_size = 1, shuffle = True)
# nn architecture
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 4)
        self.fc2 = nn.Linear(4, 2)
        self.fc3 = nn.Linear(2, 1)

    def forward(self, x):
        x = x.float()
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

model = Net()
# device
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)
# hyper-parameters
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
# training loop
for epoch in range(5):
    for batch in dataloader:
        # unpacking
        x, y = batch
        x.to(device)
        y.to(device)

        # reset gradients
        optimizer.zero_grad()

        # forward propagation through the network
        out = model(x)

        # calculate the loss
        loss = criterion(out, y)

        # backpropagation
        loss.backward()

        # update the parameters
        optimizer.step()
Error:
/opt/conda/lib/python3.7/site-packages/torch/nn/modules/loss.py:446: UserWarning: Using a target size (torch.Size([1])) that is different to the input size (torch.Size([1, 1])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.
return F.mse_loss(input, target, reduction=self.reduction)
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-18-3f68fcee9ff3> in <module>
20
21 # backpropagation
---> 22 loss.backward()
23
24 # update the parameters
/opt/conda/lib/python3.7/site-packages/torch/tensor.py in backward(self, gradient, retain_graph, create_graph)
219 retain_graph=retain_graph,
220 create_graph=create_graph)
--> 221 torch.autograd.backward(self, gradient, retain_graph, create_graph)
222
223 def register_hook(self, hook):
/opt/conda/lib/python3.7/site-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables)
130 Variable._execution_engine.run_backward(
131 tensors, grad_tensors_, retain_graph, create_graph,
--> 132 allow_unreachable=True) # allow_unreachable flag
133
134
RuntimeError: Found dtype Double but expected Float
Gul*_*zar 24
The dtype of your data needs to match the dtype of your model.
Either convert the model to double precision (recommended for a simple network like this one, where performance is not a serious concern):
# nn architecture
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 4)
        self.fc2 = nn.Linear(4, 2)
        self.fc3 = nn.Linear(2, 1)
        self.double()
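(If you take this route, also remove the x = x.float() cast in forward; otherwise the input would be float32 while the weights are float64.)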
Or convert the data to float:
class mydataset(Dataset):
    def __init__(self):
        super().__init__()

    def __getitem__(self, index):
        x = Getx(index)
        y = Gety(index)
        return x.float(), y.float()
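With float32 inputs and targets, the model's default float32 parameters already match, so nothing else needs to change.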
Abu*_*dik 14
Check the dtypes of out and y:
print(out.dtype)
print(y.dtype)
You will likely find that they differ, for example:
"torch.float32"
"torch.float64"
Set them to the same type.
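For example, a minimal sketch (assuming the same model, criterion and dataloader as in the question): cast the target to the output's dtype and reshape it to the output's shape just before computing the loss. This also silences the broadcasting warning about torch.Size([1]) vs torch.Size([1, 1]).

# training loop with dtype/shape alignment (illustrative sketch)
for epoch in range(5):
    for x, y in dataloader:
        optimizer.zero_grad()
        out = model(x)                    # model output is float32, shape [1, 1]
        y = y.to(out.dtype).view_as(out)  # cast target to float32 and reshape to [1, 1]
        loss = criterion(out, y)
        loss.backward()
        optimizer.step()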