我想微调 PyTorch 中的对象检测器。为此,我使用了本教程:
https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html
然而,FastRCNN 模型不适合我的用例,因此我对 SSDLight 进行了微调。我编写了这段代码来设置一个新的分类头:
from functools import partial
from torchvision.models.detection import _utils as det_utils
from torchvision.models.detection.ssdlite import SSDLiteClassificationHead
model = torchvision.models.detection.ssdlite320_mobilenet_v3_large(pretrained=True)
in_channels = det_utils.retrieve_out_channels(model.backbone, (320, 320))
num_anchors = model.anchor_generator.num_anchors_per_location()
norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.03)
num_classes = 2
model.head.classification_head = SSDLiteClassificationHead(in_channels, num_anchors, num_classes, norm_layer)
Run Code Online (Sandbox Code Playgroud)
由于我的模型表现不佳,想请问社区上面的代码是否正确?
提前致谢。