何时在张量流中使用 model.predict(x) 与 model(x)

Gun*_*rsi 5 keras tensorflow

我有一个keras.models.Model我加载的tf.keras.models.load_model.

现在有两种选择来使用这个模型。我可以打电话model.predict(x),也可以打电话model(x).numpy()。两个选项都给出相同的结果,但model.predict(x)运行时间要长 10 倍以上。

源代码中的注释指出:

计算是分批进行的。该方法专为大规模输入的性能而设计 。对于适合一批的少量输入,__call__建议直接使用以加快执行速度,例如 model(x), 或model(x, training=False)

我已经测试过x包含 1; 1,000,000;和 10,000,000 行,model(x)仍然表现更好。

输入需要有多大才能被归类为大规模输入,并且性能model.predict(x)更好?

Jak*_*kub 3

您可能会发现现有的堆栈溢出答案很有用:/sf/answers/4086960951/。我在tensorflow/tensorflow#33340上找到了它。该答案建议传递experimental_run_tf_function=Falsemodel.compile调用以恢复到模型执行的 TF 1.x 版本。您也可以完全省略该model.compile调用(预测不需要)。

输入需要有多大才能被归类为大规模输入,并且性能model.predict(x)更好?

这是你可以测试的。正如文档所述,model(x)这可能会比model.predict(x)您的数据放入一批中更快。model.predict(x)提供的一件事model(x)是能够预测多个批次。如果您想使用 来预测多个批次model(x),则必须自己编写循环。model.predict还提供其他功能,例如回调。

仅供参考,源代码中的文档是在提交42f469be0f3e8c36624f0b01c571e7ed15f75faf中添加的,这是tensorflow/tensorflow#33340的结果。

的主要行为在这里model.predict(x)实现。它不仅仅包含模型的前向传递。这可以解释一些速度差异。

我测试过 x 包含 1; 1,000,000;10,000,000 行和 model(x) 仍然表现更好。

这 10,000,000 行适合一批吗...?