如何从 json 注释图像制作 tfrecords

Muh*_*han 4 python json object-detection tensorflow object-detection-api

我已经使用 tensorflow object detection-api 来训练我自己的对象检测器。但当时,图像是使用labelimg注释的,它为每个图像创建 xml 文件。现在我得到了每个图像都有 json 文件的标记图像。那么如何我使用这些 json 文件来创建 tfrecords。

Muh*_*han 5

首先,我使用自己的脚本创建了 csv 文件。

import os
import glob
import pandas as pd
import json
import pickle

def json_to_csv():
    path_to_json = 'images/train/'
    json_files = [pos_json for pos_json in os.listdir(path_to_json) if pos_json.endswith('.json')]
    path_to_jpeg = 'images/train/'
    jpeg_files = [pos_jpeg for pos_jpeg in os.listdir(path_to_jpeg) if pos_jpeg.endswith('.jpeg')]
    fjpeg=(list(reversed(jpeg_files)))
    n=0
    csv_list = []
    labels=[]
    for j in json_files:
        data_file=open('images/train/{}'.format(j))   
        data = json.load(data_file)
        width,height=data['display_width'],data['display_height']
        for item in data["items"]:
            box = item['bounding_box']
            if item['upc']!='None':
                name=item['upc']
                labels.append(name)
                xmin=box['left']
                ymin=box['top']
                xmax=box['right']
                ymax=box['bottom']
                value = (fjpeg[n],
                         width,
                         height,
                         name,
                         xmin,
                         ymin,
                         xmax,
                         ymax
                         )
                csv_list.append(value)
          n=n+1
    column_name = ['filename', 'width', 'height', 'class', 'xmin', 'ymin', 'xmax', 'ymax']
    csv_df = pd.DataFrame(csv_list, columns=column_name)
    labels_train=list(set(labels))
    with open("train_labels.txt", "wb") as fp:   #Pickling
        pickle.dump(labels_train, fp)
    return csv_df

def main():
    for directory in ['train']:
        csv_df = json_to_csv()
        csv_df.to_csv('data/{}_labels.csv'.format(directory), index=None)
        print('Successfully converted json to csv.')

main()
Run Code Online (Sandbox Code Playgroud)

然后我使用这个脚本来创建 tfrecords。