mat*_*ang 3 neural-network flask keras tensorflow
I'm using a pretrained Keras model, and calling ResNet50(weights='imagenet') raises an error. I have the following code in a Flask server:
# Imports assumed; the question does not show them.
from keras.applications.vgg16 import VGG16
from keras.applications.resnet50 import ResNet50, preprocess_input, decode_predictions
from keras.preprocessing import image
import numpy as np

def getVGG16Prediction(img_path):
    model = VGG16(weights='imagenet', include_top=True)
    img = image.load_img(img_path, target_size=(224, 224))
    x = image.img_to_array(img)
    x = np.expand_dims(x, axis=0)
    x = preprocess_input(x)
    pred = model.predict(x)
    # `sort` is not defined in the snippet as posted
    return sort(decode_predictions(pred, top=3)[0])

def getResNet50Prediction(img_path):
    model = ResNet50(weights='imagenet')  # ERROR HERE
    img = image.load_img(img_path, target_size=(224, 224))
    x = image.img_to_array(img)
    x = np.expand_dims(x, axis=0)
    x = preprocess_input(x)
    preds = model.predict(x)
    return decode_predictions(preds, top=3)[0]
When called from main, it works fine:
if __name__ == "__main__":
    STATIC_PATH = os.getcwd() + "/static"
    print(getVGG16Prediction(STATIC_PATH + "/18.jpg"))
    print(getResNet50Prediction(STATIC_PATH + "/18.jpg"))
However, when I call it from a Flask POST handler, a ValueError is raised:
@app.route("/uploadMultipleImages", methods=["POST"])
def uploadMultipleImages():
    uploaded_files = request.files.getlist("file[]")
    weight = request.form.get("weight")
    for file in uploaded_files:
        path = os.path.join(STATIC_PATH, file.filename)
        file.save(os.path.join(STATIC_PATH, file.filename))
        result = getResNet50Prediction(path)
The full error is:
ValueError: Tensor("cond/pred_id:0", dtype=bool) must be from the same graph as Tensor("batchnorm/add_1:0", shape=(?, 112, 112, 64), dtype=float32).
Any comments or suggestions are much appreciated. Thanks.
You will need to open separate sessions and specify which graph each session uses; otherwise Keras will replace each graph by default.
from tensorflow import Graph, Session
from keras.models import load_model
from keras import backend as K
Loading the graphs:
graph1 = Graph()
with graph1.as_default():
    session1 = Session()
    with session1.as_default():
        model = load_model('foo.h5')

graph2 = Graph()
with graph2.as_default():
    session2 = Session()
    with session2.as_default():
        model2 = load_model('foo2.h5')
Predicting / using the graph:
K.set_session(session1)
with graph1.as_default():
    result = model.predict(data)
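With several models, the load and predict steps above could be wrapped in two small helpers. This is a sketch assuming TF 1.x, where `Graph.as_default()` and `Session.as_default()` both act as context managers and a `Session()` created without a `graph` argument binds to the current default graph. `load_isolated`, `model_context`, `Stub` and `fake_load` are hypothetical names; the stubs only exist so the sketch runs without TensorFlow. Note that `model_context` enters the session with `as_default()` instead of calling `K.set_session`; that assumes the Keras backend picks up TensorFlow's default session when one is active, which is worth checking for your Keras version.

```python
from contextlib import contextmanager


def load_isolated(path, load_fn, graph_factory, session_factory):
    """Load one model into a fresh graph and session; return all three."""
    graph = graph_factory()
    with graph.as_default():
        session = session_factory()  # created inside, so it binds to `graph`
        with session.as_default():
            model = load_fn(path)
    return graph, session, model


@contextmanager
def model_context(graph, session):
    """Re-enter the same graph, then the same session, before using the model."""
    with graph.as_default(), session.as_default():
        yield


# --- Hypothetical stubs so the helpers can run without TensorFlow ---
events = []


class Stub:
    def __init__(self, label):
        self.label = label

    @contextmanager
    def as_default(self):
        events.append("enter " + self.label)
        try:
            yield self
        finally:
            events.append("exit " + self.label)


def fake_load(path):
    events.append("load " + path)
    return "model from " + path


g, s, m = load_isolated("foo.h5", fake_load,
                        lambda: Stub("graph"), lambda: Stub("session"))
with model_context(g, s):
    events.append("predict with " + m)

print(events)
```

With the real classes this would presumably be called as `graph1, session1, model = load_isolated('foo.h5', load_model, Graph, Session)`, followed by `with model_context(graph1, session1): result = model.predict(data)`. Keeping the graph, session and model together as one tuple makes it harder to predict with a model while the wrong graph is active.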
Views: 3943