我在训练 Keras 模型时使用自定义指标。它工作得很好,除了输出中的指标名称model.fit_generator(...)
不可解释(注意:Tensorboard 也使用这些错误的名称)。
这是我正在做的事情的一个可重现的示例:指标使用一个参数(除了预测和真实值之外),因此我定义了一个工厂来生成无参数指标函数,类似于:
def my_dummy_metric(y_true, y_pred, the_param=1.0):
return the_param * keras.backend.ones((1))
def my_metric_factory(the_param=1.0):
def fn(y_true, y_pred):
return my_dummy_metric(y_true, y_pred, the_param=the_param)
return fn
my_second_metric = my_metric_factory(2.0)
my_other_metric = my_metric_factory(3.14)
Run Code Online (Sandbox Code Playgroud)
然后我编译并训练我的模型:
model.compile(my_optim, my_loss, [my_second_metric, my_other_metric])
history = model.fit_generator(...)
print(history.params['metrics'])
Run Code Online (Sandbox Code Playgroud)
我的麻烦是,中的指标名称history
是fn
、和。Tensorboard 也使用这些名称,您需要了解实现细节才能理解它们。fn_1
val_fn
val_fn_1
相反,当我使用简单的自定义函数,没有工厂时,我没有这个问题:
model.compile(my_optim, my_loss, [my_dummy_metric])
history = model.fit_generator(...)
print(history.params['metrics'])
Run Code Online (Sandbox Code Playgroud)
my_XXX_metric
在基于工厂的用例中是否也可以获取输出名称?
环境:使用Keras 2.2.4、TF 1.14.0、Python 3.7
rvi*_*nas 10
是的,这是可能的。在度量工厂中,只需设置适当__name__
的度量函数即可。例如:
def my_metric_factory(the_param=1.0):
def fn(y_true, y_pred):
return my_dummy_metric(y_true, y_pred, the_param=the_param)
fn.__name__ = 'metricname_{}'.format(the_param)
return fn
Run Code Online (Sandbox Code Playgroud)
归档时间: |
|
查看次数: |
5142 次 |
最近记录: |