Keras - 如何使用 argmax 进行预测

nen*_*ese 4 python numpy argmax keras

我有 3 类课程Tree, Stump, Ground。我为这些类别列出了一个清单:

CATEGORIES = ["Tree", "Stump", "Ground"]
Run Code Online (Sandbox Code Playgroud)

当我打印我的预测时,它给了我输出

[[0. 1. 0.]]
Run Code Online (Sandbox Code Playgroud)

我已经阅读了关于 numpy 的 Argmax,但我不完全确定在这种情况下如何使用它。

我试过使用

print(np.argmax(prediction))
Run Code Online (Sandbox Code Playgroud)

但这给了我1. 太好了,但我想找出索引是什么1,然后打印出类别而不是最高值。

import cv2
import tensorflow as tf
import numpy as np

CATEGORIES = ["Tree", "Stump", "Ground"]


def prepare(filepath):
    IMG_SIZE = 150 # This value must be the same as the value in Part1
    img_array = cv2.imread(filepath, cv2.IMREAD_GRAYSCALE)
    new_array = cv2.resize(img_array, (IMG_SIZE, IMG_SIZE))
    return new_array.reshape(-1, IMG_SIZE, IMG_SIZE, 1)

# Able to load a .model, .h3, .chibai and even .dog
model = tf.keras.models.load_model("models/test.model")

prediction = model.predict([prepare('image.jpg')])
print("Predictions:")
print(prediction)
print(np.argmax(prediction))

Run Code Online (Sandbox Code Playgroud)

我希望我的预测告诉我:

Predictions:
[[0. 1. 0.]]
Stump
Run Code Online (Sandbox Code Playgroud)

感谢阅读 :) 我非常感谢任何帮助。

Mat*_*gro 7

您只需要使用以下结果对类别进行索引np.argmax

pred_name = CATEGORIES[np.argmax(prediction)]
print(pred_name)
Run Code Online (Sandbox Code Playgroud)

  • 你能告诉我数组是如何排序的吗?为什么不是[“树桩”、“树”、“地面”]?谢谢 (2认同)