Hel*_*688 5 python deep-learning conv-neural-network cross-entropy pytorch
*My Training Model*
def train(model, criterion, optimizer, iters):
    epochs = iters
    train_loss = []
    validation_loss = []
    train_acc = []
    validation_acc = []
    states = ['Train', 'Valid']
    for epoch in range(epochs):
        print("epoch : {}/{}".format(epoch + 1, epochs))
        for phase in states:
            if phase == 'Train':
                model.train()  # training mode when the phase is 'Train'
                dataload = train_data_loader
            else:
                model.eval()
                dataload = valid_data_loader
            run_loss, run_acc = 0, 0  # running totals for loss and accuracy
            for data in dataload:
                inputs, labels = data
                inputs = inputs.to(device)
                labels = labels.to(device)
                labels = labels.byte()
                optimizer.zero_grad()  # reset gradients from the previous step
                with torch.set_grad_enabled(phase == 'Train'):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels.unsqueeze(1).float())
                    predict = outputs >= 0.5
                    if phase == 'Train':
                        loss.backward()  # backward propagation
                        optimizer.step()
                    acc = torch.sum(predict == labels.unsqueeze(1))
                run_loss += loss.item()
                run_acc += acc.item() / len(labels)
            if phase == 'Train':  # calculating train loss and accuracy
                epoch_loss = run_loss / len(train_data_loader)
                train_loss.append(epoch_loss)
                epoch_acc = run_acc / len(train_data_loader)
                train_acc.append(epoch_acc)
            else:  # calculating validation loss and accuracy
                epoch_loss = run_loss / len(valid_data_loader)
                validation_loss.append(epoch_loss)
                epoch_acc = run_acc / len(valid_data_loader)
                validation_acc.append(epoch_acc)
            print("{}, loss :{},accuracy:{}".format(phase, epoch_loss, epoch_acc))
    history = {'Train_loss': train_loss, 'Train_accuracy': train_acc,
               'Validation_loss': validation_loss, 'Validation_Accuracy': validation_acc}
    return model, history
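I am getting the error "0D or 1D target tensor expected, multi-target not supported". Could you help me correct the code above? I went through earlier related posts but could not get the result I wanted. Which parts of the code do I have to change for my model to run successfully? Any suggestions are welcome. Thanks in advance.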
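Your problem is that the labels do not have the correct shape for computing the loss. When you add .unsqueeze(1) to the labels, they end up with shape [32,1], which does not match what the loss function requires.
To fix this, you only need to remove .unsqueeze(1) from the labels.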
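To make the shape issue concrete, here is a minimal sketch that reproduces the failure and the fix outside the training loop. The random tensors are stand-ins for one batch of 32 samples and 3 classes, and the `.long()` cast is an assumption on my part: when the target holds class indices, `CrossEntropyLoss` is normally given integer (int64) labels rather than floats.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
criterion = nn.CrossEntropyLoss()

# Stand-ins for one batch (assumed sizes: 32 samples, 3 classes).
outputs = torch.randn(32, 3)           # raw model scores, shape [32, 3]
labels = torch.randint(0, 3, (32,))    # class indices, shape [32]

# Target shaped as in the question: [32, 1] float. This is rejected.
try:
    criterion(outputs, labels.unsqueeze(1).float())
    failed = False
except (RuntimeError, ValueError) as err:
    failed = True
    print("rejected:", err)

# Fix: pass the labels as a 1D tensor of class indices.
loss = criterion(outputs, labels.long())
print(loss.item())
```

In the training loop this would amount to calling `criterion(outputs, labels.long())` instead of `criterion(outputs, labels.unsqueeze(1).float())`.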
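If you read the documentation of CrossEntropyLoss, the arguments are:
the input, which is outputs in your case, with shape [32,3]; and the target, which is labels in your case, which should have shape [32]. Hence, the loss function expects labels to be a 1D target, not multi-target.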
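With the input shaped [N, C] and the target shaped [N], the accuracy line in the question also needs the same shapes on both sides. The sketch below is one possible way to do that, not part of the original answer: `batch_loss_and_accuracy` is a hypothetical helper, and taking `argmax` over the class dimension in place of the `outputs >= 0.5` threshold is my assumption for a three-class output.

```python
import torch
import torch.nn as nn

def batch_loss_and_accuracy(criterion, outputs, labels):
    """Hypothetical helper: loss and accuracy for one batch of class-index labels."""
    labels = labels.long().reshape(-1)        # target: shape [N], dtype int64
    assert outputs.dim() == 2                 # input: shape [N, C]
    assert outputs.size(0) == labels.size(0)  # same batch size on both sides
    loss = criterion(outputs, labels)
    predict = outputs.argmax(dim=1)           # predicted class per sample, shape [N]
    acc = (predict == labels).float().mean().item()
    return loss, acc

# Hand-made scores: the largest value in each row marks the predicted class.
outputs = torch.tensor([[2.0, 0.1, 0.1],
                        [0.1, 2.0, 0.1],
                        [0.1, 0.1, 2.0],
                        [2.0, 0.1, 0.1]])
labels = torch.tensor([0, 1, 2, 1])           # the last label disagrees on purpose

loss, acc = batch_loss_and_accuracy(nn.CrossEntropyLoss(), outputs, labels)
print(acc)  # 0.75
```

The helper returns the mean accuracy of the batch, which matches the per-batch value that the question accumulates in `run_acc`.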