Mnist:获取混淆矩阵

Ahm*_*mad 2 python confusion-matrix keras tensorflow

我尝试获取 mnist 数据集的混淆矩阵。

这是我的代码:

mnist = tf.keras.datasets.mnist

(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(),
  tf.keras.layers.Dense(512, activation=tf.nn.tanh),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.Dense(10, activation=tf.nn.softmax)
])
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])


model.fit(x_train, y_train, epochs=1, callbacks=[history])

test_predictions = model.predict(x_test)


# Compute confusion matrix
confusion = tf.confusion_matrix(y_test, test_predictions)
Run Code Online (Sandbox Code Playgroud)

问题是这test_prediction是一个 10000 x 10 矩阵,而 y_test 是 10000 x 1 矩阵。事实上,神经网络不会为每个测试样本提供输出。对于这种情况,如何计算混淆矩阵?

然后我如何呈现混淆矩阵?我可以为此目的使用 sci-kit 库吗?

Bis*_*dal 5

这可能是因为您的预测包含所有可能类别的概率。您需要选择概率最高的类,这将导致与 y_test 相同的维度。您可以使用 numpy 中的 argmax() 方法。它的工作原理是这样的:

import numpy as np
a = np.array([[0.9,0.1,0],[0.2,0.3,0.5],[0.4,0.6,0]])
np.argmax(a, axis=0)
array([0, 2, 1])
Run Code Online (Sandbox Code Playgroud)

您可以使用 sklearn 来生成混淆矩阵。你的代码会变成这样

from sklearn.metrics import confusion_matrix
import numpy as np

confusion = confusion_matrix(y_test, np.argmax(test_predictions,axis=1))
Run Code Online (Sandbox Code Playgroud)


dru*_*cik 5

如果您使用 .predict_classes 方法而不仅仅是预测,您将获得概率最高的类向量。

然后,您可以使用sklearn中的confusion_matrix。

test_predictions = model.predict_classes(x_test)

from sklearn.metrics import confusion_matrix

cm = confusion_matrix(y_true = y_test, y_pred = test_predictions)
print(cm)
Run Code Online (Sandbox Code Playgroud)

这里test_predictions的形状是(10000,)。

打印的结果将类似于:

array([[ 967,    1,    1,    2,    0,    1,    5,    0,    2,    1],
   [   0, 1126,    3,    1,    0,    1,    1,    0,    3,    0],
   [   3,    2, 1001,    8,    1,    0,    3,    6,    8,    0],
   [   0,    0,    1, 1002,    0,    1,    0,    1,    5,    0],
   [   3,    1,    2,    2,  955,    2,    6,    1,    3,    7],
   [   3,    1,    0,   37,    1,  833,    9,    0,    6,    2],
   [   4,    3,    1,    1,    1,    3,  941,    0,    4,    0],
   [   2,    9,    8,    5,    0,    0,    0,  988,    8,    8],
   [   3,    1,    3,   10,    3,    2,    2,    3,  946,    1],
   [   3,    8,    0,   10,    8,    8,    1,    4,    5,  962]],
  dtype=int64)
Run Code Online (Sandbox Code Playgroud)