Olg*_*ova 9 deep-learning keras tensorflow
我正在构建一个用于二值图像分类的简单 CNN,从 model.evaluate() 获得的AUC 远高于从 model.predict() + roc_auc_score() 获得的 AUC。
整个笔记本都在这里。
为 model.fit() 编译模型和输出:
model.compile(loss='binary_crossentropy',
optimizer=RMSprop(lr=0.001),
metrics=['AUC'])
history = model.fit(
train_generator,
steps_per_epoch=8,
epochs=5,
verbose=1)
Run Code Online (Sandbox Code Playgroud)
Epoch 1/5 8/8 [================================] - 21s 3s/step - 损失:6.7315 - auc : 0.5143
Epoch 2/5 8/8 [================================] - 15s 2s/step - 损失:0.6626 - auc : 0.6983
Epoch 3/5 8/8 [================================] - 18s 2s/step - 损失:0.4296 - auc : 0.8777
Epoch 4/5 8/8 [================================] - 14s 2s/step - 损失:0.2330 - auc : 0.9606
Epoch 5/5 8/8 [================================] - 18s 2s/step - 损失:0.1985 - auc : 0.9767
然后 model.evaluate() 给出类似的东西:
model.evaluate(train_generator)
Run Code Online (Sandbox Code Playgroud)
9/9 [==============================] - 10 秒 1 秒/步 - 损失:0.3056 - auc:0.9956
但随后直接从 model.predict() 方法计算的 AUC 低了两倍:
from sklearn import metrics
x = model.predict(train_generator)
metrics.roc_auc_score(train_generator.labels, x)
Run Code Online (Sandbox Code Playgroud)
0.5006148007590132
我已经阅读了几篇关于类似问题的帖子(比如this、this、this以及关于 github 的广泛讨论),但它们描述了与我的案例无关的原因:
任何建议都非常感谢。谢谢!
编辑!解决方案 我在这里找到了解决方案,我只需要打电话
train_generator.reset()
Run Code Online (Sandbox Code Playgroud)
在 model.predict 之前,并在 flow_from_directory() 函数中设置 shuffle = False 。不同的原因是生成器从不同的位置开始输出批次,因此标签和预测将不匹配,因为它们涉及不同的对象。所以问题不在于评估或预测方法,而在于生成器。
编辑 2 如果使用 flow_from_directory() 创建生成器,则使用 train_generator.reset() 不方便,因为它需要在 flow_from_directory 中设置 shuffle = False,但这将在训练期间创建包含单个类的批次,从而影响学习。所以我最终在运行预测之前重新定义了 train_generator。
tensorflow.kerasAUC 通过黎曼和计算近似 AUC(曲线下面积),这与 scikit-learn 的实现不同。
如果您想使用 查找 AUC tensorflow.keras,请尝试:
import tensorflow as tf
m = tf.keras.metrics.AUC()
m.update_state(train_generator.labels, x) # assuming both have shape (N,)
r = m.result().numpy()
print(r)
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
1685 次 |
| 最近记录: |