警告:tensorflow:最近 11 次调用中的 11 次触发了 tf.function 回溯

Sof*_*dez 10 python warnings tensorflow jupyter-notebook

有人知道这个错误的原因吗?

WARNING:tensorflow:No training configuration found in the save file, so the model was *not* compiled. Compile it manually.
WARNING:tensorflow:11 out of the last 11 calls to <function Model.make_predict_function.<locals>.predict_function at 0x000001F9D1C05EE0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/tutorials/customization/performance#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
WARNING:tensorflow:11 out of the last 11 calls to <function Model.make_predict_function.<locals>.predict_function at 0x000001F9D5604670> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/tutorials/customization/performance#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
C:\Users\User\anaconda3\lib\site-packages\sklearn\cluster\_kmeans.py:973: FutureWarning: 'n_jobs' was deprecated in version 0.23 and will be removed in 0.25.
  warnings.warn("'n_jobs' was deprecated in version 0.23 and will be"
Run Code Online (Sandbox Code Playgroud)

小智 12

{TLDR}尝试更换model.predict(X)模型(x)的

我的解决方案

我也遇到了警告问题:

WARNING:tensorflow:11 out of the last 11 calls to <function Model.make_predict_function.<locals>.predict_function at 0x000001F9D1C05EE0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/tutorials/customization/performance#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
Run Code Online (Sandbox Code Playgroud)

我能够通过直接使用model(x)替换model.predict(x)来解决它


我遇到问题的背景信息

我正在预测时间序列,并将每个新采样时间的模型的最后一层拟合到最新数据。因此我

  1. 生成并拟合一个基础模型并冻结所有层 + 放置一个顶层
  2. 适应新数据并在循环内进行预测

我尝试使用警告和@TFer2 中建议的签名来实现自定义预测函数。然而,这产生了错误

RuntimeError: Detected a call to `Model.predict` inside a `tf.function`. `Model.predict is a high-level endpoint that manages its own `tf.function`. Please move the call to `Model.predict` outside of all enclosing `tf.function`s. Note that you can call a `Model` directly on `Tensor`s inside a `tf.function` like: `model(x)`.
Run Code Online (Sandbox Code Playgroud)

有了这个错误,我就能够解决这个问题。


小智 2

如果您调用具有相同参数类型的函数,张量流将重用之前跟踪的图,否则将创建新图。

函数通过计算 a 来确定是否重用跟踪的具体函数cache key from an input's args and kwargs

  • 为参数生成的密钥tf.Tensor是它的shapeand type输入签名
  • 为参数生成的键tf.Variable是它的id()
  • 为原语生成的密钥python是它的value
  • dicts, lists, tuples, namedtuples为嵌套, 和生成的键attrsflattened tuple.

回溯可确保张量流为每组输入生成正确的图。但价格昂贵。

你必须避免过度的回溯,否则张量流通常会发出如上所述的警告。

有几种方法可以控制跟踪行为:

  • 指定一个input_signature in tf.function
  • 指定 a[None] dimension in tf.TensorSpec以允许跟踪重用的灵活性
  • Cast python arguments to Tensors以减少回溯

有关更多详细信息,您可以参阅使用 tf.function 获得更好的性能