自定义指标访问X输入数据

Lar*_*ado 5 metrics keras tensorflow

I'd like to write a custom metric for a spelling correction model that counts correctly substituted letters that were previously incorrect. And it should be counted incorrectly substituted letters that were previously correct.

这就是为什么我需要访问 x_input 数据。不幸的是,默认情况下只能访问 y_true 和 y_pred。是否有解决方法可以获取匹配的 x_input?

是:

def custom_metric(y_true, y_pred):
Run Code Online (Sandbox Code Playgroud)

通缉:

def custom_metric(x_input, y_true, y_pred):
Run Code Online (Sandbox Code Playgroud)

use*_*882 2

def custom_loss(x_input):
    def loss_fn(y_true, y_pred):
        # Use your x_input here directly
        return #Your loss value
    return loss_fn

model = # Define your model
model.compile(loss=custom_loss(x_input))   
# Values of y_true and y_pred will be passed implicitly by Keras
Run Code Online (Sandbox Code Playgroud)

请记住,x_input在训练模型时,所有批次的输入都将具有相同的值。

编辑

由于您x_input需要每批次的数据来在损失函数期间进行估计,并且您有自己的自定义损失函数,因此为什么不传递作为标签。像这样的东西:x_input

model.fit(x=x_input, y=x_input)
model.compile(loss=custom_loss())

def custom_loss(y_true, y_pred):
  # y_true corresponds to x_input data
Run Code Online (Sandbox Code Playgroud)

如果你需要x_input并且需要传递一些其他数据,你可以这样做:

model.fit(x=x_input, y=[x_input, other_data])
model.compile(loss=custom_loss())
Run Code Online (Sandbox Code Playgroud)

您现在只需要解耦数据即可y_true