Tensorflow Estimator API:如何从输入函数传递参数

Sta*_*tal 6 tensorflow tensorflow-estimator

我正在尝试将类权重添加为我的模型的超参数,但是为了计算权重,我需要读取输入数据,这发生在input_fn中,然后传递给它estimator.fit().输出input_fn只是要素,标签应具有相同的形状num_examples*num_features.我的问题 - 有没有办法将数据从input_fn传播到model_fn的超参数映射?或者作为替代 - 也许有一个input_fn数据集的包装器允许过度采样少数/下采样多数以及批处理 - 在这种情况下我不需要任何参数来传播.

Sor*_*rin 1

特征和标签都可以是张量的字典(而不仅仅是一个张量)。张量可以是您想要的任何形状,但通常是 num_examples * ...

如果您不使用任何预定义的估计器,最简单的方法是添加另一个功能来计算权重,计算模型中的权重,然后使用它们(乘以损失或将其作为参数传递) 。

您还可以访问 input_fn 内的超参数,以便您可以计算其中的权重并将其添加为单独的列。

如果您使用预装估算器,请检查文档。我看到他们中的大多数都支持weight_column_name。在这种情况下,只需为其指定您在特征字典中用于权重值的名称即可。

或者,如果所有其他方法都失败,您可以在将数据输入到张量流之前按照您想要的方式对数据进行采样。