Tensorflow 2.0:从回调访问批次的张量

fra*_*isr 10 python keras tensorflow tf.keras tensorflow2.0

我正在使用 Tensorflow 2.0 并尝试编写一个tf.keras.callbacks.Callback读取model批处理的输入和输出。

我希望能够覆盖on_batch_end和访问model.inputsmodel.outputs但它们没有EagerTensor我可以访问的值。有没有办法访问批处理中涉及的实际张量值?

这有许多实际用途,例如将这些张量输出到 Tensorboard 以进行调试,或将它们序列化以用于其他目的。我知道我可以使用再次运行整个模型,model.predict但这将迫使我通过网络运行每个输入两次(而且我可能还有非确定性数据生成器)。关于如何实现这一目标的任何想法?

Yao*_*ang 1

不,无法在回调中访问输入和输出的实际值。这不仅仅是回调设计目标的一部分。回调只能访问模型、适合的参数、纪元号和一些指标值。正如您所发现的, model.input 和 model.output 仅指向符号 KerasTensors,而不是实际值。

要执行您想要的操作,您可以获取输入,将其(可能使用 RaggedTensor)与您关心的输出堆叠在一起,然后将其作为模型的额外输出。然后将您的功能实现为仅读取 y_pred 的自定义指标。在您的指标中,解开 y_pred 以获取输入和输出,然后可视化/序列化/等。

另一种方法可能是实现一个自定义层,该层使用 py_function 在 python 中调用回函数。这在严格的训练期间会非常慢,但对于诊断/调试期间可能足够使用。