警告:tensorflow:sample_weight 模式被强制从 ... 到 ['...']

jor*_*mit 61 python keras tensorflow tf.keras tensorflow2.0

使用.fit_generator()or训练图像分类器.fit()并将字典传递给class_weight=作为参数。

我在 TF1.x 中从未出错,但在 2.1 中,我在开始训练时得到以下输出:

WARNING:tensorflow:sample_weight modes were coerced from
  ...
    to  
  ['...']
Run Code Online (Sandbox Code Playgroud)

将某物从 强制...到是什么意思['...']

了解关于这一警告源tensorflow的回购是在这里,摆放的意见是:

尝试将 sample_weight_modes 强制为目标结构。这隐含地取决于模型将其内部表示的输出展平的事实。

jlh*_*jlh 22

这似乎是一条虚假消息。升级到 TensorFlow 2.1 后,我收到相同的警告消息,但我根本不使用任何类权重或样本权重。我确实使用了一个返回这样的元组的生成器:

return inputs, targets
Run Code Online (Sandbox Code Playgroud)

现在我只是将其更改为以下内容以使警告消失:

return inputs, targets, [None]
Run Code Online (Sandbox Code Playgroud)

我不知道这是否相关,但我的模型使用 3 个输入,所以我的inputs变量实际上是 3 个 numpy 数组的列表。 targets只是一个 numpy 数组。

无论如何,这只是一个警告。无论哪种方式,培训都可以正常工作。

针对 TensorFlow 2.2 进行编辑:

这个 bug 似乎在 TensorFlow 2.2 中已经修复了,这很棒。然而,上面的修复将在 TF 2.2 中失败,因为它会尝试获取样本权重的形状,这显然会失败AttributeError: 'NoneType' object has no attribute 'shape'。所以在升级到 2.2 时撤消上述修复。


Max*_*Max 18

我相信这是 tensorflow 的一个错误,当您model.compile()使用默认参数sample_weight_mode=None调用然后model.fit()使用指定的sample_weightclass_weight.

从张量流回购:

  • fit() 最终打电话 _process_training_inputs()
  • _process_training_inputs() 设置 sample_weight_modes = [None]基于model.sample_weight_mode = None然后创建一个DataAdapterwithsample_weight_modes = [None]
  • DataAdapter通话broadcast_sample_weight_modes()sample_weight_modes = [None]初始化
  • broadcast_sample_weight_modes() 似乎期待 sample_weight_modes = None但收到[None]
  • 它断言这[None]是与sample_weight/不同的结构class_weightNone通过拟合sample_weight/的结构将其覆盖回,class_weight并输出警告

警告抛开这有没有影响fit()作为sample_weight_modesDataAdapter重新设置为None

请注意,tensorflow文档指出它sample_weight必须是一个 numpy 数组。如果改为调用fit()with sample_weight.tolist(),则不会收到警告,但会sample_weight被静默覆盖到Nonewhen_process_numpy_inputs()预处理中调用并接收长度大于 1 的输入。

  • 非常详尽的解释,谢谢。我唯一不明白的是,警告描述了“...”被强制为“[...]”,而在你的情况下,“[None]”被强制为“None”... (2认同)

Ten*_*ort 7

我已经采用了您的 Gist 并安装了 Tensorflow 2.0,而不是 TFA,并且它在没有任何此类警告的情况下工作。

这是完整代码的要点。安装 Tensorflow 的代码如下所示:

!pip install tensorflow==2.0
Run Code Online (Sandbox Code Playgroud)

执行成功截图如下:

在此处输入图片说明

更新:此错误已在Tensorflow Version 2.2.

  • 感谢您的答复。你是对的,直到版本`2.1.0rc0`才引入警告消息。然而,恐怕我的问题仍然存在:“将某些内容从‘...’强制为‘['...']`意味着什么?” (5认同)
  • 我注意到,当“sample_weight_mode=None”和“target_struct”的类型为“dict”时,可能会发生一些意想不到的事情,“sample_weight_modes”则为“[None]”,并且由于“dict”而捕获了“broadcast_sample_weight_modes”中的异常。这可以被视为一个错误吗? (3认同)
  • 没有。问题不断收集观点和点赞,但没有答案。 (3认同)