我有一个keras.models.Model我加载的tf.keras.models.load_model.
现在有两种选择来使用这个模型。我可以打电话model.predict(x),也可以打电话model(x).numpy()。两个选项都给出相同的结果,但model.predict(x)运行时间要长 10 倍以上。
源代码中的注释指出:
计算是分批进行的。该方法专为大规模输入的性能而设计 。对于适合一批的少量输入,
__call__建议直接使用以加快执行速度,例如model(x), 或model(x, training=False)
我已经测试过x包含 1; 1,000,000;和 10,000,000 行,model(x)仍然表现更好。
输入需要有多大才能被归类为大规模输入,并且性能model.predict(x)更好?
您可能会发现现有的堆栈溢出答案很有用:/sf/answers/4086960951/。我在tensorflow/tensorflow#33340上找到了它。该答案建议传递experimental_run_tf_function=False到model.compile调用以恢复到模型执行的 TF 1.x 版本。您也可以完全省略该model.compile调用(预测不需要)。
输入需要有多大才能被归类为大规模输入,并且性能
model.predict(x)更好?
这是你可以测试的。正如文档所述,model(x)这可能会比model.predict(x)您的数据放入一批中更快。model.predict(x)提供的一件事model(x)是能够预测多个批次。如果您想使用 来预测多个批次model(x),则必须自己编写循环。model.predict还提供其他功能,例如回调。
仅供参考,源代码中的文档是在提交42f469be0f3e8c36624f0b01c571e7ed15f75faf中添加的,这是tensorflow/tensorflow#33340的结果。
的主要行为在这里model.predict(x)实现。它不仅仅包含模型的前向传递。这可以解释一些速度差异。
我测试过 x 包含 1; 1,000,000;10,000,000 行和 model(x) 仍然表现更好。
这 10,000,000 行适合一批吗...?
| 归档时间: |
|
| 查看次数: |
4348 次 |
| 最近记录: |