使用单个GPU预测Keras模型的多处理

Eas*_*ang 6 multiprocessing python-2.7 keras tensorflow

背景

我想使用带有Inception-Resnet_v2的keras预测病理图像。我已经训练了模型,并得到了一个.hdf5文件。因为病理图像非常大(例如:20,000 x 20,000像素),所以我必须扫描图像以获得小的斑块进行预测。

我想使用带有python2.7的多处理库来加快预测过程。主要思想是使用不同的子流程扫描不同的线,然后发送补丁进行建模。

我看到有人建议在子流程中导入keras和加载模型。但是我认为这不适合我的任务。一次加载模型keras.models.load_model()大约需要47秒,这非常耗时。因此,我无法在每次启动新的子流程时都重新加载模型。

我的问题是我可以在主流程中加载模型并将其作为参数传递给子流程吗?

我尝试了两种方法,但它们均无效。

方法1。使用multiprocessing.Pool

代码是:

import keras
from keras.models import load_model
import multiprocessing

def predict(num,model):
    print dir(model)
    print num
    model.predict("image data, type:list")

if __name__ == '__main__':
    model = load_model("path of hdf5 file")
    list = [(1,model),(2,model),(3,model),(4,model),(5,model),(6,model)]
    pool = multiprocessing.Pool(4)
    pool.map(predict,list)
    pool.close()
    pool.join()
Run Code Online (Sandbox Code Playgroud)

输出是

cPickle.PicklingError: Can't pickle <type 'module'>: attribute lookup __builtin__.module failed
Run Code Online (Sandbox Code Playgroud)

我搜索了该错误,发现Pool无法映射无法选择的参数,因此我尝试了方法2。

方法二。使用multiprocessing.Process

该代码是

import keras
from keras.models import load_model
import multiprocessing

def predict(num,model):
    print num
    print dir(model)
    model.predict("image data, type:list")

if __name__ == '__main__':
    model = load_model("path of hdf5 file")
    list = [(1,model),(2,model),(3,model),(4,model),(5,model),(6,model)]
    proc = []
    for i in range(4):
        proc.append(multiprocessing.Process(predict, list[i]))
        proc[i].start()
    for i in range(4):
        proc[i].join()
Run Code Online (Sandbox Code Playgroud)

在方法2中,我可以打印dir(model)。我认为这意味着模型已成功传递给子流程。但是我得到了这个错误

E tensorflow/stream_executor/cuda/cuda_driver.cc:1296] failed to enqueue async memcpy from host to device: CUDA_ERROR_NOT_INITIALIZED; GPU dst: 0x13350b2200; host src: 0x2049e2400; size: 4=0x4
Run Code Online (Sandbox Code Playgroud)

我使用的环境:

  • Ubuntu 16.04,Python 2.7
  • keras 2.0.8(tensorflow后端)
  • 一个Titan X,驱动程序版本384.98,CUDA 8.0

期待回复!谢谢!

Sta*_*ham 0

也许你可以使用 apply_async() 而不是 Pool()

您可以在这里找到更多详细信息:

Python 多处理酸洗错误