将类标签附加到Keras模型

Cer*_*rno 4 python keras

我正在使用Keras Sequential模型来训练许多多类分类器.

在评估时,Keras输出置信度向量,我可以使用argmax推断出正确的类ID.然后我可以使用查找表来接收实际的类标签(例如字符串).

到目前为止,解决方案是加载训练的模型,然后单独加载查找表.由于我有很多分类器,我宁愿将两个结构保存在一个文件中.

所以我正在寻找的是一种将实际标签查找向量集成到Keras模型中的方法.这将允许我有一个分类器文件,它能够获取一些输入数据并返回该数据的正确类标签.

解决这个问题的一种方法是将模型和查找表存储在一个元组中,并将该元组写入一个pickle,但这似乎并不优雅.

Cer*_*rno 7

所以我亲自尝试了一个解决方案,这似乎有效.我希望有更简单的东西.

我认为第二次打开模型文件并不是最佳选择.如果有人能做得更好,无论如何都要做.

import h5py

from keras.models import load_model
from keras.models import save_model


def load_model_ext(filepath, custom_objects=None):
    model = load_model(filepath, custom_objects=None)
    f = h5py.File(filepath, mode='r')
    meta_data = None
    if 'my_meta_data' in f.attrs:
        meta_data = f.attrs.get('my_meta_data')
    f.close()
    return model, meta_data


def save_model_ext(model, filepath, overwrite=True, meta_data=None):
    save_model(model, filepath, overwrite)
    if meta_data is not None:
        f = h5py.File(filepath, mode='a')
        f.attrs['my_meta_data'] = meta_data
        f.close()
Run Code Online (Sandbox Code Playgroud)