Fab*_*ioL 6 python machine-learning deep-learning keras tensorflow
我正在 Keras 中研究多输出模型。我实现了两个自定义指标 auroc 和 auprc,它们被传递给compile
Keras 模型的方法:
def auc(y_true, y_pred, curve='PR'):
score, up_opt = tf.compat.v1.metrics.auc(y_true, y_pred, curve=curve, summation_method="careful_interpolation")
K.get_session().run(tf.local_variables_initializer())
with tf.control_dependencies([up_opt]):
score = tf.identity(score)
return score
def auprc(y_true, y_pred):
return auc(y_true, y_pred, curve='PR')
def auroc(y_true, y_pred):
return auc(y_true, y_pred, curve='ROC')
mlp_model.compile(loss=...,
optimizer=...,
metrics=[auprc, auroc])
Run Code Online (Sandbox Code Playgroud)
使用此方法,我获得每个输出的 auprc/auroc 值,但是,为了使用贝叶斯优化器优化我的超参数,我需要一个指标(例如:每个输出的 auprc 的平均值或总和)。我不知道如何将我的指标加入到一个指标中。
编辑:这里是所需结果的示例
现在,对于每个时期,都会打印以下指标:
out1_auprc: 0.0267 - out2_auprc: 0.0277 - out3_auprc: 0.0294
Run Code Online (Sandbox Code Playgroud)
其中out1
, out2
,out3
是我的神经网络输出,我希望获得如下结果:
average_auprc: 0.0279 - out1_auprc: 0.0267 - out2_auprc: 0.0277 - out3_auprc: 0.0294
Run Code Online (Sandbox Code Playgroud)
我正在使用 Keras Tuner 进行贝叶斯优化。
如有任何帮助,我们将不胜感激,谢谢。
我覆盖了创建自定义回调的问题
class MergeMetrics(Callback):
def __init__(self,**kargs):
super(MergeMetrics,self).__init__(**kargs)
def on_epoch_begin(self,epoch, logs={}):
return
def on_epoch_end(self, epoch, logs={}):
logs['merge_metrics'] = 0.5*logs["y1_mse"]+0.5*logs["y2_mse"]
Run Code Online (Sandbox Code Playgroud)
我使用此回调来合并来自 2 个不同输出的 2 个指标。例如,我使用一个简单的问题,但您可以轻松地将其集成到您的问题中并将其与验证集集成
这是虚拟示例
X = np.random.uniform(0,1, (1000,10))
y1 = np.random.uniform(0,1, 1000)
y2 = np.random.uniform(0,1, 1000)
inp = Input((10))
x = Dense(32, activation='relu')(inp)
out1 = Dense(1, name='y1')(x)
out2 = Dense(1, name='y2')(x)
m = Model(inp, [out1,out2])
m.compile('adam','mae', metrics='mse')
checkpoint = MergeMetrics()
m.fit(X, [y1,y2], epochs=10, callbacks=[checkpoint])
Run Code Online (Sandbox Code Playgroud)
打印输出
loss: ..... y1_mse: 0.0863 - y2_mse: 0.0875 - merge_metrics: 0.0869
Run Code Online (Sandbox Code Playgroud)
归档时间: |
|
查看次数: |
1495 次 |
最近记录: |