将sample_weights与fit_generator()一起使用

use*_*421 2 machine-learning time-series generator autoregressive-models keras

在自回归连续问题中,当零位占据过多位置时,可以将情况视为零膨胀问题(即ZIB)。换句话说,不是要拟合f(x),我们要拟合g(x)*f(x)where f(x)是我们要近似的函数,即y,并且g(x)是一个根据值是零还是非零而输出介于0和1之间的值的函数。

目前,我有两个模型。一个给我的g(x)模型,另一个给我的模型g(x)*f(x)

第一个模型给了我一组权重。这是我需要您帮助的地方。我可以将sample_weights参数与一起使用model.fit()。当我处理大量数据时,则需要使用model.fit_generator()。但是,fit_generator()没有论点sample_weights

sample_weights内部解决方案fit_generator()吗?否则,g(x)*f(x)如果我已经有训练有素的模型,我该如何适应g(x)

tod*_*day 7

您可以提供样本权重作为生成器返回的元组的第三个元素。从Keras文档上fit_generator

生成器:生成器或Sequencekeras.utils.Sequence)对象的实例,以便在使用多重处理时避免重复数据。发电机的输出必须是

  • 一个元组 (inputs, targets)
  • 一个元组(inputs, targets, sample_weights)

更新:这是生成器的粗略草图,该生成器返回输入的样本和目标以及从模型获得的样本权重g(x)

def gen(args):
    while True:
        for i in range(num_batches):
            # get the i-th batch data
            inputs = ...
            targets = ...

            # get the sample weights
            weights = g.predict(inputs)

            yield inputs, targets, weights


model.fit_generator(gen(args), steps_per_epoch=num_batches, ...)
Run Code Online (Sandbox Code Playgroud)