为多个模型指定CPU或GPU tensorflow java的工作

Ale*_*lex 10 java gpu keras tensorflow

我正在使用Tensorflow java API(1.8.0),我加载了多个模型(在不同的会话中).使用SavedModelBundle.load(...)方法从.pb文件加载这些模型.这些.pb文件是通过保存Keras的模型获得的.

假设我想加载3个模型A,B,C.为此,我实现了一个java Model类:

public class Model implements Closeable {

private String inputName;
private String outputName;
private Session session;
private int inputSize;

public Model(String modelDir, String input_name, String output_name, int inputSize) {
    SavedModelBundle bundle = SavedModelBundle.load(modelDir, "serve");
    this.inputName = input_name;
    this.outputName = output_name;
    this.inputSize = inputSize;
    this.session = bundle.session();
}

public void close() {
    session.close();
}

public Tensor predict(Tensor t) {
    return session.runner().feed(inputName, t).fetch(outputName).run().get(0);
}
}
Run Code Online (Sandbox Code Playgroud)

然后我可以轻松地使用此类实例化与我的A,B和C模型相对应的3个Model对象,并在同一个java程序中使用这3个模型进行预测.我还注意到如果我有一个GPU,就会加载3个模型.

但是,我只希望模型A在GPU上运行并强制其他2个在CPU上运行.

通过阅读文档并深入了解源代码,我没有找到方法.我试图将可见设备的新ConfigProto定义为None,并使用图形实例化一个新的Session,但它不起作用(参见下面的代码).

    public Model(String modelDir, String input_name, String output_name, int inputSize) {
      SavedModelBundle bundle = SavedModelBundle.load(modelDir, "serve");
      this.inputName = input_name;
      this.outputName = output_name;
      this.inputSize = inputSize;
      ConfigProto configProto = ConfigProto.newBuilder().setAllowSoftPlacement(false).setGpuOptions(GPUOptions.newBuilder().setVisibleDeviceList("").build()).build();
      this.session = new Session(bundle.graph(),configProto.toByteArray());
}
Run Code Online (Sandbox Code Playgroud)

当我加载模型时,它使用可用的GPU.你对这个问题有什么解决方案吗?

谢谢您的回答.

小智 1

根据这个问题,新的源代码修复了这个问题。不幸的是,您必须按照这些说明从源代码构建

然后你可以测试:

ConfigProto configProto = ConfigProto.newBuilder()
                .setAllowSoftPlacement(true) // allow less GPUs than configured
                .setGpuOptions(GPUOptions.newBuilder().setPerProcessGpuMemoryFraction(0.01).build())
                .build();
SavedModelBundle  bundle = SavedModelBundle.loader(modelDir).withTags("serve").withConfigProto(configProto.toByteArray()).load();
Run Code Online (Sandbox Code Playgroud)