使用model.fit_generator时如何获得混淆矩阵

Hit*_*esh 11 confusion-matrix keras

我正在使用model.fit_generator来训练并获得我的二进制(两个类)模型的结果,因为我直接从我的文件夹中提供输入图像.如何在这种情况下得到混淆矩阵(TP,TN,FP,FN),因为我通常使用confusion_matrix命令sklearn.metrics来获取它,这需要predictedactual标签.但在这里我没有两者.可能是我可以从predict=model.predict_generator(validation_generator)命令计算预测标签.但我不知道我的模型是如何从我的图像中获取输入标签的.我的输入文件夹的一般结构是:

train/
 class1/
     img1.jpg
     img2.jpg
     ........
 class2/
     IMG1.jpg
     IMG2.jpg
test/
 class1/
     img1.jpg
     img2.jpg
     ........
 class2/
     IMG1.jpg
     IMG2.jpg
     ........
Run Code Online (Sandbox Code Playgroud)

我的代码的一些块是:

train_generator = train_datagen.flow_from_directory('train',  
        target_size=(50, 50),  batch_size=batch_size,
        class_mode='binary',color_mode='grayscale')  


validation_generator = test_datagen.flow_from_directory('test',
        target_size=(50, 50),batch_size=batch_size,
        class_mode='binary',color_mode='grayscale')

model.fit_generator(
        train_generator,steps_per_epoch=250 ,epochs=40,
        validation_data=validation_generator,
        validation_steps=21 )
Run Code Online (Sandbox Code Playgroud)

所以上面的代码自动接受两个类输入,但我不知道它认为是哪个类0和哪个类1.

wl2*_*776 5

我通过以下方式管理它,使用keras.utils.Sequence.

from sklearn.metrics import confusion_matrix
from keras.utils import Sequence


class MySequence(Sequence):
    def __init__(self, *args, **kwargs):
        # initialize
        # see manual on implementing methods

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        # return index-th complete batch


# create data generator
data_gen = MySequence(evaluation_set, batch_size=10) 

n_batches = len(data_gen)

confusion_matrix(
    np.concatenate([np.argmax(data_gen[i][1], axis=1) for i in range(n_batches)]),    
    np.argmax(m.predict_generator(data_gen, steps=n_batches), axis=1) 
)
Run Code Online (Sandbox Code Playgroud)

实现的类以元组形式返回批量数据,这允许不要将所有数据都保存在 RAM 中。请注意,它必须在 中实现__getitem__,并且此方法必须为相同的参数返回相同的批次。

不幸的是,这段代码迭代数据两次:第一次,它从返回的批次中创建真实答案数组,第二次它调用predict模型的方法。


bla*_*tor 3

您可以通过调用或对象class_indices上的属性来查看从类名到类索引的映射,如下所示train_generatorvalidation_generator

train_generator.class_indices