Keras线程安全吗?

Bri*_*ian 20 python multithreading keras

我正在使用Python和Keras(目前正在使用Theano后端,但我对切换没有任何疑虑).我有一个神经网络,我并行加载和处理多个信息源.目前,我已经在一个单独的进程中运行每个进程,并从文件中加载自己的网络副本.这似乎浪费了RAM,所以我认为让一个多线程进程与一个所有线程使用的网络实例相比会更有效率.但是,我想知道Keras是否与后端的线程安全.如果我.predict(x)在不同的线程中同时运行两个不同的输入,我是否会遇到竞争条件或其他问题?

谢谢

Dar*_*ero 19

是的,如果你注意它,Keras是线程安全的.

实际上,在强化学习中,有一种叫做Asynchronous Advantage Actor Critics(A3C)的算法,其中每个代理依赖于相同的神经网络来告诉他们在给定状态下应该做什么.换句话说,每个线程model.predict在您的问题中同时调用.这里有Keras的示例实现.

但是,如果你查看代码,你应该特别注意这一行: model._make_predict_function() # have to initialize before threading

这在Keras文档中从未被提及,但它必须使它同时工作.简而言之,_make_predict_function是一个编译predict功能的函数.在多线程设置中,您必须predict事先手动调用此函数进行编译,否则predict在第一次运行该函数之前不会编译该函数,这会在许多线程一次调用它时出现问题.你可以在这里看到详细的解释.

到目前为止,我还没有遇到过Keras多线程的任何其他问题.

  • 我认为丹尼尔的回答更加准确和最新。 (2认同)

小智 6

引用那种fcholet

_make_predict_function 是一个私有 API。我们不建议调用它。

在这里,用户应该先简单地调用 predict。

请注意,不能保证 Keras 模型是线程安全的。考虑在每个线程中拥有模型的独立副本以进行 CPU 推理。