I am training a resnet18 on the CIFAR100 dataset. After about 50 epochs the validation accuracy converges at around 34%, while the training accuracy reaches almost 100%.
I suspect it is overfitting a bit, so I applied data augmentation such as RandomHorizontalFlip and RandomRotation, which brings the validation accuracy up to around 40%.
I also tried decaying the learning rate over [0.1, 0.03, 0.01, 0.003, 0.001], stepping down every 50 epochs. The learning-rate decay does not seem to improve performance.
I have heard that ResNet on CIFAR100 can reach 70%~80% accuracy. What other tricks can I apply? Or is there something wrong with my implementation? The same code reaches about 80% accuracy on CIFAR10.
My full training and evaluation code is below:
import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader
from torchvision.models import resnet18
from torchvision.transforms import Compose, ToTensor, RandomHorizontalFlip, RandomRotation, Normalize
from torchvision.datasets import CIFAR10, CIFAR100
import os
from datetime import datetime
import matplotlib.pyplot as plt


def draw_loss_curve(histories, legends, save_dir):
    # Plot one curve per recorded metric (loss_train, acc_train, acc_val).
    os.makedirs(save_dir, exist_ok=True)
    for key in histories[0][0].keys():
        if key != "epoch":
            plt.figure()
            plt.title(key)
            for history in histories:
                x = [h["epoch"] for h in history]
                y = [h[key] for h in history]
                # plt.ylim(ymin=0, ymax=3.0)
                plt.plot(x, y)
            plt.legend(legends)
            plt.savefig(os.path.join(save_dir, key + ".png"))


def cal_acc(out, label):
    # Top-1 accuracy of a batch of logits against integer labels.
    batch_size = label.shape[0]
    pred = torch.argmax(out, dim=1)
    num_true = torch.nonzero(pred == label).shape[0]
    acc = num_true / batch_size
    return torch.tensor(acc)


class LrManager(optim.lr_scheduler.LambdaLR):
    # LambdaLR multiplies the optimizer's base lr by the rate returned by f,
    # so `lrs` maps epoch thresholds to multipliers, not absolute lrs.
    def __init__(self, optimizer, lrs):
        def f(epoch):
            rate = 1
            for k in sorted(lrs.keys()):
                if epoch >= k:
                    rate = lrs[k]
                else:
                    break
            return rate

        super(LrManager, self).__init__(optimizer, f)


def main(cifar=100, epochs=250, batches_show=100):
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
        print("warning: CUDA is not available, using CPU instead")

    dataset_cls = CIFAR10 if cifar == 10 else CIFAR100
    dataset_train = dataset_cls(root=f"data/{dataset_cls.__name__}/", download=True, train=True,
                                transform=Compose([RandomHorizontalFlip(), RandomRotation(15), ToTensor(),
                                                   Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))]))
    dataset_val = dataset_cls(root=f"data/{dataset_cls.__name__}/", download=True, train=False,
                              transform=Compose([ToTensor(), Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))]))
    loader_train = DataLoader(dataset_train, batch_size=128, shuffle=True)
    loader_val = DataLoader(dataset_val, batch_size=128, shuffle=True)

    model = resnet18(pretrained=False, num_classes=cifar).to(device)
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5)
    lr_scheduler = LrManager(optimizer, {0: 1.0, 50: 0.3, 100: 0.1, 150: 0.03, 200: 0.01})
    criterion = nn.CrossEntropyLoss()

    history = []
    model.train()
    for epoch in range(epochs):
        print("------------------- TRAINING -------------------")
        loss_train = 0.0
        running_loss = 0.0
        acc_train = 0.0
        running_acc = 0.0
        for batch, data in enumerate(loader_train, 1):
            img, label = data[0].to(device), data[1].to(device)
            optimizer.zero_grad()
            pred = model(img)
            loss = criterion(pred, label)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            loss_train += loss.item()
            acc = cal_acc(pred, label)
            running_acc += acc.item()
            acc_train += acc.item()
            if batch % batches_show == 0:
                print(f"epoch: {epoch}, batch: {batch}, loss: {running_loss/batches_show:.4f}, acc: {running_acc/batches_show:.4f}")
                running_loss = 0.0
                running_acc = 0.0
        loss_train = loss_train / batch
        acc_train = acc_train / batch
        lr_scheduler.step()

        print("------------------- EVALUATING -------------------")
        with torch.no_grad():
            running_acc = 0.0
            for batch, data in enumerate(loader_val, 1):
                img, label = data[0].to(device), data[1].to(device)
                pred = model(img)
                acc = cal_acc(pred, label)
                running_acc += acc.item()
            acc_val = running_acc / batch

        print(f"epoch: {epoch}, acc_val: {acc_val:.4f}")
        history.append({"epoch": epoch, "loss_train": loss_train, "acc_train": acc_train, "acc_val": acc_val})
        draw_loss_curve([history], legends=[f"resnet18-CIFAR{cifar}"], save_dir=f"history/resnet18-CIFAR{cifar}[{datetime.now()}]")


if __name__ == '__main__':
    main()
The resnet18 from torchvision.models is the ImageNet implementation. Because ImageNet samples (224x224) are much larger than CIFAR10/100 samples (32x32), the first layers (the "stem network") are designed to aggressively downsample the input. This loses a lot of valuable information on the small CIFAR10/100 images.
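If you want to keep using torchvision's resnet18, a commonly used workaround is to patch the stem so it stops downsampling. The sketch below is illustrative rather than something from the paper; it assumes a torchvision version where the model exposes conv1 and maxpool as plain attributes (current releases do):

import torch
from torch import nn
from torchvision.models import resnet18

# Sketch: adapt the ImageNet-style resnet18 stem to 32x32 CIFAR inputs.
model = resnet18(num_classes=100)

# Replace the 7x7/stride-2 stem conv with a 3x3/stride-1 conv so the first
# residual block still sees the full 32x32 feature map.
model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)

# Drop the 3x3/stride-2 max pooling, which would halve the resolution again.
model.maxpool = nn.Identity()

# Quick shape check with a CIFAR-sized batch.
out = model(torch.randn(2, 3, 32, 32))
print(out.shape)  # torch.Size([2, 100])

With this change the network only downsamples inside the residual stages (32 -> 16 -> 8 -> 4), which is much closer to the CIFAR variants from the paper.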
To achieve good accuracy on CIFAR10, the authors used a different network structure, described in the original paper (https://arxiv.org/pdf/1512.03385.pdf) and explained in this article: https://towardsdatascience.com/resnets-for-cifar-10-e63e900524e0
You can download a ResNet for CIFAR10 from this repo: https://github.com/akamaster/pytorch_resnet_cifar10
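For reference, a minimal usage sketch, assuming you put the repo's resnet.py next to your script; the resnet20() factory and the `linear` classifier attribute are what that repo's resnet.py uses at the time of writing, so double-check them against the current code:

from torch import nn
from resnet import resnet20  # resnet.py from akamaster/pytorch_resnet_cifar10

# The repo's models default to 10 classes (CIFAR10). For CIFAR100, swap the
# classifier head; `linear` is the attribute name assumed from the repo.
model = resnet20()
model.linear = nn.Linear(model.linear.in_features, 100)

You could then use this model in place of the resnet18(...) line in the training script above.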