加载带有标签的 Tensorflow 模型

Tak*_*aka 4 python tensorflow

model.save('model')在本教程之后使用存储了一个模型:

https://towardsdatascience.com/keras-transfer-learning-for-beginners-6c9b8b7143e

标签取自目录本身。现在我想加载它并使用以下代码对图像进行预测:

import numpy as np
import tensorflow as tf
from tensorflow import keras
from keras.preprocessing import image

new_model = keras.models.load_model('./model/')

# Check its architecture
new_model.summary()

with image.load_img('testpics/mypic.jpg') as img: # , target_size=(32,32)) 
    img  = image.img_to_array(img)
    img  = img.reshape((1,) + img.shape)
    # img  = img/255
    # img = img.reshape(-1,784)
    img_class=new_model.predict(img) 
    prediction = img_class[0]
    classname = img_class[0]
    print("Class: ",classname)
Run Code Online (Sandbox Code Playgroud)

可悲的是输出只是

类别:[1.3706615e-03 2.9885881e-03 1.6783881e-03 3.0293325e-03 2.9168031e-03 7.2344812e-04 2.0196944e-06 2.0119224e-02 2.2996603 e-04 1.1960276e-05 3.0794670e-04 6.0808496e- 05 1.4892215e-05 1.5410941e-02 1.2452166e-04 8.2580920e-09 2.4049083e-02 3.1140331e-05 7.4609083e-01 1.5793210e-01 2.428325 6e-03 1.5755130e-04 2.4227127e-03 2.2325735e-07 7.2101393 e-06 7.6298704e-03 2.0922457e-04 1.2269774e-03 5.5882465e-06 2.4516811e-04 8.5745640e-03]

我不知道如何重新加载标签...有人可以帮我吗:/?

这是我保存的文件的屏幕截图

Tob*_*her 7

该模型不包含标签名称。因此无法通过这种方式检索。您必须在训练时保存标签,然后可以在预测阶段加载和使用它们。

我使用 pickle 将标签作为序列化数组存储在文件中。然后您可以加载它们并使用预测的 argmax 作为数组索引。

这是训练阶段:

CLASS_NAMES = ['ClassA', 'ClassB'] # should be dynamic
f = open('labels.pickle', "wb")
f.write(pickle.dumps(CLASS_NAMES))
f.close()
Run Code Online (Sandbox Code Playgroud)

并且在预测中:

CLASS_NAMES = pickle.loads(open('labels.pickle', "rb").read())
predictions = model.predict(predict_image)
result = CLASS_NAMES[predictions.argmax(axis=1)[0]]
Run Code Online (Sandbox Code Playgroud)