keras模型部署

server.py

#! -*- coding: utf-8 -*-
from tensorflow.python.client import device_lib
from tensorflow import keras as k
from PIL import Image
import tensorflow as tf
import numpy as np
import base64
import flask
import json
import sys
import io

MODEL_PATH = "/xxx/xxx.hdf5"  # placeholder path to the saved Keras model
IMG_W = 224  # width passed to prepare_image() as the resize target
IMG_H = 224  # height passed to prepare_image() as the resize target
IMG_C = 3  # channel count; prepare_image() converts every image to RGB

# allow_growth: allocate GPU memory on demand instead of reserving it all
# up front. Alternative (commented out): cap usage at a fixed fraction.
gpu_options = tf.GPUOptions(
    allow_growth=True,
    # per_process_gpu_memory_fraction=0.99,
)
gpu_config = tf.ConfigProto(
    log_device_placement=False,
    allow_soft_placement=True,  # fall back to CPU for ops without a GPU kernel
    gpu_options=gpu_options,
)

# Create and install the configured session BEFORE touching the GPU or
# loading the model:
#   * TF fixes the GPU allocator options when the device is first
#     initialised in the process, so this session must come before
#     list_local_devices() for allow_growth to take effect;
#   * Keras loads the weights into whichever session is current at load
#     time. Installing a new session only after load_model() would leave
#     prediction running in a session whose variables never received the
#     loaded weights.
sess = tf.Session(config=gpu_config)
k.backend.set_session(sess)

local_device_protos = device_lib.list_local_devices()
num_gpus = sum(1 for d in local_device_protos if d.device_type == "GPU")
print("GPUS:", num_gpus)

# load model (into `sess`, installed above)
print("* Loading model ...")
model = k.models.load_model(MODEL_PATH)
if num_gpus > 1:
    # Keras' multi_gpu_model replicates the model on each GPU and splits
    # every input batch across the replicas.
    model = k.utils.multi_gpu_model(model, num_gpus)
# Build the predict function now, in the main thread, rather than lazily
# on the first request (which may arrive on a Flask worker thread).
model._make_predict_function()
print("* Model loaded.")


def prepare_image(image, target):
    """Turn a PIL image into a one-sample float32 batch for ``model.predict``.

    Args:
        image: a PIL image in any mode; non-RGB images are converted to RGB.
        target: ``(width, height)`` size handed to ``Image.resize``.

    Returns:
        A float32 ``numpy.ndarray`` with a new leading batch axis of size 1
        (for an RGB image this is ``(1, height, width, 3)``).

    Pixel values keep their original range; no scaling or normalisation is
    applied here.
    """
    rgb = image if image.mode == "RGB" else image.convert("RGB")
    pixels = np.asarray(rgb.resize(target), "float32")
    return pixels[np.newaxis, ...]


def base64_decode_image(a):
    """Decode a base64 payload into raw bytes.

    Args:
        a: base64 text (``str``) or an already-encoded ``bytes``/``bytearray``
            object, e.g. the ``"image"`` field of a JSON request body.

    Returns:
        The decoded bytes.

    Raises:
        binascii.Error: if the payload is incorrectly padded.
    """
    # Text is UTF-8 encoded first, so any non-ASCII characters are discarded
    # as non-alphabet bytes (b64decode's validate=False default) rather than
    # rejected. Bytes-like input is passed straight through, where the old
    # version-sniffing code raised TypeError on Python 3.
    if isinstance(a, str):
        a = a.encode("utf-8")
    return base64.b64decode(a)


# The Flask application object; the prediction routes are registered on it.
app = flask.Flask(import_name=__name__)


def _json_response(data):
    """Return ``json.dumps(data)`` as an ``application/json`` response.

    Returning the bare string would make Flask send it as ``text/html``.
    Wrapping it keeps the body byte-identical while advertising the right
    content type.
    """
    return flask.Response(json.dumps(data), mimetype="application/json")


def _load_input(raw_bytes):
    """Decode encoded image bytes (any format PIL can open) into a model batch.

    Raises:
        IOError/OSError: if PIL cannot identify or fully read the image.
        ``Image.open`` is lazy, so truncated data may only fail inside
        ``prepare_image``; both steps therefore belong in the caller's try.
    """
    image = Image.open(io.BytesIO(raw_bytes))
    return prepare_image(image, (IMG_W, IMG_H))


def _classify(batch):
    """Run the model on *batch* and return one result dict per sample.

    Each dict holds ``pred``, the argmax index over axis 1 of the model
    output, and ``score``, the output value at that index.
    """
    print("* inputs shape:", batch.shape)
    # Flask's server may handle each request on a worker thread, and TF1's
    # default session and graph are per-thread. Bind the module-level session
    # (and its graph) explicitly, so Keras uses it rather than a session it
    # might create for this thread. That fallback varies across TF 1.x
    # versions.
    with sess.as_default(), sess.graph.as_default():
        outputs = model.predict(batch)
    preds = np.argmax(outputs, axis=1)
    results = []
    for pred, output in zip(preds, outputs):
        r = {"pred": int(pred), "score": float(output[pred])}
        print(r)
        results.append(r)
    return results


@app.route("/predict", methods=["POST"])
def predict():
    """Classify an image uploaded as the multipart form field ``image``.

    Always responds with JSON: ``{"success": True, "results": [...]}`` on
    success, otherwise ``{"success": False, "err_msg": "..."}``.
    (The route only accepts POST, so no method check is needed here.)
    """
    data = {"success": False}
    upload = flask.request.files.get("image")
    if not upload:
        data["err_msg"] = "Get none."
        return _json_response(data)
    try:
        batch = _load_input(upload.read())
    except (IOError, OSError):  # PIL's UnidentifiedImageError subclasses OSError
        data["err_msg"] = "Invalid image."
        return _json_response(data)
    data["results"] = _classify(batch)
    # /predict previously never reported success; keep it consistent with /server.
    data["success"] = True
    return _json_response(data)


@app.route("/server", methods=["POST"])
def server():
    """Classify a base64 image posted as the JSON body ``{"image": "<base64>"}``.

    Always responds with JSON: ``{"success": True, "results": [...]}`` on
    success, otherwise ``{"success": False, "err_msg": "..."}``.
    (The route only accepts POST, so no method check is needed here.)
    """
    data = {"success": False}
    raw = flask.request.get_data()
    if not raw:
        data["err_msg"] = "Get none."
        return _json_response(data)
    try:
        # ValueError: bad JSON or bad UTF-8. KeyError: no "image" field.
        # TypeError: the body is JSON but not an object.
        encoded = json.loads(raw)["image"]
    except (ValueError, KeyError, TypeError):
        data["err_msg"] = "Invalid request body."
        return _json_response(data)
    try:
        # binascii.Error (bad padding) subclasses ValueError; TypeError covers
        # a non-string "image" value; IOError/OSError come from PIL.
        batch = _load_input(base64_decode_image(encoded))
    except (ValueError, TypeError, IOError, OSError):
        data["err_msg"] = "Invalid image."
        return _json_response(data)
    data["results"] = _classify(batch)
    data["success"] = True
    return _json_response(data)


if __name__ == "__main__":
    # Start Flask's built-in development server with its default host/port.
    # Use a production WSGI server for real deployments.
    app.run()

client.py

#! -*- coding: utf-8 -*-
from urllib import request
import base64
import json

# Image file that main() reads and sends to the server.
IMAGE_PATH: str = "test.png"
# Base URL of the Flask server. main() appends the route name ("server"),
# so this must end with "/".
HOST: str = "http://127.0.0.1:5000/"


def base64_encode_image(a):
    """Return the base64 encoding of the bytes *a* as a text string.

    The text form can be embedded directly in a JSON payload, such as the
    ``"image"`` field that main() posts to the server.
    """
    encoded = base64.b64encode(a)
    # Base64 output is pure ASCII, so decoding it cannot fail.
    return str(encoded, "utf-8")


def main(image_path=None, host=None, timeout=None):
    """Send an image to the server's ``/server`` route and print the JSON reply.

    The image file is base64-encoded and POSTed as ``{"image": "<base64>"}``
    with ``Content-Type: application/json``.

    Args:
        image_path: image file to send; defaults to ``IMAGE_PATH``.
        host: server base URL ending in "/"; defaults to ``HOST``.
        timeout: optional socket timeout in seconds for ``urlopen``. ``None``
            keeps urllib's default behaviour.

    Returns:
        The decoded JSON response (it is also printed).

    Raises:
        OSError: if the image file cannot be read.
        urllib.error.URLError: on connection failures. Its subclass
            ``HTTPError`` covers non-2xx responses.
    """
    # Resolve defaults at call time so later changes to the constants apply.
    if image_path is None:
        image_path = IMAGE_PATH
    if host is None:
        host = HOST

    # `with` closes the file handle, which the original left open.
    with open(image_path, "rb") as f:
        image = base64_encode_image(f.read())

    body = json.dumps({"image": image}).encode("utf-8")
    req = request.Request(
        host + "server",
        data=body,
        headers={"Content-Type": "application/json"},
    )
    # Only pass `timeout` when given; passing None would disable the global
    # socket default timeout instead of inheriting it.
    urlopen_kwargs = {} if timeout is None else {"timeout": timeout}
    with request.urlopen(req, **urlopen_kwargs) as res:
        result = json.loads(res.read())
    print(result)
    return result


if __name__ == "__main__":
    # Send the demo request only when run as a script, not on import.
    main()

發表評論
所有評論
還沒有人評論，想成為第一個評論的人嗎？請在上方評論欄輸入並點擊發布。
相關文章