TensorFlow对象检测API奇怪的行为

Ban*_*ski 19 python classification machine-learning object-detection tensorflow

我正在使用TensorFlow的全新Object Detection API,并决定在其他一些公开的数据集上进行训练.

我偶然偶然发现这个杂货数据集,其中包括超市货架上各种品牌香烟盒的图像,以及一个文本文件,其中列出了每个图像中每个香烟盒的边界框.数据集中已标注10个主要品牌,所有其他品牌均属于第11个"杂项"类别.

我按照他们的教程设法在这个数据集上训练模型.由于处理能力的限制,我只使用了数据集的三分之一,并进行了70:30分割,用于训练和测试数据.我使用了faster_rcnn_resnet101模型.配置文件中的所有参数与TF提供的默认参数相同.

在16491个全局步骤之后,我在一些图像上测试了模型,但我对结果不太满意 -

无法在顶层检测到Camels,而在其他图像中检测到该产品

为什么它没能检测到顶行的万宝路?

我遇到的另一个问题是模型从未检测到除标签1之外的任何其他标签

未从训练数据中检测到产品的裁剪实例

即使在负像中,它也可以99%的置信度检测香烟盒!

有人可以帮我解决问题吗?我该怎么做才能提高准确度?为什么它会检测到属于第1类的所有产品,尽管我已经提到总共有11个类?

编辑添加了我的标签贴图:

item {
  id: 1
  name: '1'
}

item {
  id: 2
  name: '2'
}

item {
  id: 3
  name: '3'
}

item {
  id: 4
  name: '4'
}

item {
  id: 5
  name: '5'
}

item {
  id: 6
  name: '6'
}

item {
  id: 7
  name: '7'
}

item {
  id: 8
  name: '8'
}

item {
  id: 9
  name: '9'
}

item {
  id: 10
  name: '10'
}

item {
  id: 11
  name: '11'
}
Run Code Online (Sandbox Code Playgroud)

Ban*_*ski 14

所以我想我弄清楚发生了什么.我对数据集进行了一些分析,发现它偏向于类别1的对象.

这是每个类别的频率分布从1到11(基于0的索引)

0 10440
1 304
2 998
3 67
4 412
5 114
6 190
7 311
8 195
9 78
10 75
Run Code Online (Sandbox Code Playgroud)

我猜这个模型正在达到一个局部最小值,只是将所有标记为类别1就足够了.

关于没有检测到一些盒子的问题:我再次尝试了培训,但这次我没有区分品牌.相反,我试图教模型香烟盒是什么.它仍然没有检测到所有的盒子.

然后我决定裁剪输入图像并将其作为输入提供.只是为了看看结果是否有所改善而且确实如此!

事实证明,输入图像的尺寸远大于模型所接受的600 x 1024.因此,它将这些图像缩小到600 x 1024,这意味着香烟盒正在丢失它们的细节:)

因此,我决定测试原始模型,该模型在裁剪后的图像上进行了所有课程的训练,它就像一个魅力:)

原始图像

这是原始图像上模型的输出

左上角从原始图像裁剪

当我裁剪左上角并将其作为输入时,这是模型的输出.

谢谢大家的帮助!并祝贺TensorFlow团队为API做出了惊人的工作:)现在每个人都可以训练对象检测模型!


小智 5

数据集中有多少张图像?您拥有的训练数据越多,API 的性能就越好。我尝试在每类大约 20 张图像上训练它,准确度很差。我几乎遇到了你上面提到的所有问题。当我生成更多数据时,准确性大大提高。

PS:抱歉我没有足够的声望无法发表评论