Keras模型的predict和predict_on_batch方法有什么区别?

Phy*_*ade 14 python deep-learning keras

根据keras 文档:

predict_on_batch(self, x)
Returns predictions for a single batch of samples.
Run Code Online (Sandbox Code Playgroud)

但是,predict在批处理调用时,标准方法似乎没有任何差别,无论是使用一个还是多个元素.

model.predict_on_batch(np.zeros((n, d_in)))
Run Code Online (Sandbox Code Playgroud)

是相同的

model.predict(np.zeros((n, d_in)))
Run Code Online (Sandbox Code Playgroud)

(一个numpy.ndarray形状(n, d_out)

GPh*_*ilo 17

不同之处在于,当您传递的x数据大于一个批次时.

predict逐批浏览所有数据,预测标签.因此,它在内部分批进行分批并一次喂食一批.

predict_on_batch另一方面,假设您传入的数据恰好是一个批次,从而将其提供给网络.它不会尝试拆分它(根据您的设置,如果阵列非常大,可能会对您的GPU内存造成问题)


小智 8

与在单个批次上执行的预测相比,predict_on_batch 似乎要快得多。

  • 批次及型号信息
    • 批量形状:(1024、333)
    • 批处理数据类型:float32
    • 模型参数:~150k
  • 时间结果:
    • 预测:~1.45 秒
    • 批量预测:~95.5 毫秒

总之,predict 方法有额外的操作来确保正确处理批次集合,而 Predict_on_batch 是应该在单个批次上使用的预测的轻量级替代方案。