小编Enj*_*ith的帖子

微调 SSD Light torchvision

我想微调 PyTorch 中的对象检测器。为此,我使用了本教程:

https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html

然而,FastRCNN 模型不适合我的用例,因此我对 SSDLight 进行了微调。我编写了这段代码来设置一个新的分类头:

from functools import partial
from torchvision.models.detection import _utils as det_utils
from torchvision.models.detection.ssdlite import SSDLiteClassificationHead
    
model = torchvision.models.detection.ssdlite320_mobilenet_v3_large(pretrained=True)

in_channels = det_utils.retrieve_out_channels(model.backbone, (320, 320))
num_anchors = model.anchor_generator.num_anchors_per_location()
norm_layer  = partial(nn.BatchNorm2d, eps=0.001, momentum=0.03)
num_classes = 2
model.head.classification_head = SSDLiteClassificationHead(in_channels, num_anchors, num_classes, norm_layer)
Run Code Online (Sandbox Code Playgroud)

由于我的模型表现不佳,想请问社区上面的代码是否正确?

提前致谢。

python object-detection torch pytorch torchvision

6
推荐指数
1
解决办法
1856
查看次数

标签 统计

object-detection ×1

python ×1

pytorch ×1

torch ×1

torchvision ×1