from tensorflow.python.client import device_lib
from tensorflow import keras as k
from PIL import Image
import tensorflow as tf
import numpy as np
import base64
import flask
import json
import sys
import io
# Saved Keras model to serve. "/xxx/xxx.hdf5" looks like a placeholder —
# NOTE(review): point this at the real HDF5 file before deploying.
MODEL_PATH = "/xxx/xxx.hdf5"
# Target size passed to prepare_image() as (width, height); IMG_C is
# presumably the channel count (it is not referenced in the visible code).
IMG_W, IMG_H, IMG_C = 224, 224, 3
# Let TensorFlow grow GPU memory on demand instead of reserving it all up front.
gpu_options = tf.GPUOptions(
    allow_growth=True,
)
# Session configuration: fall back to another device when an op cannot be
# placed where requested, and do not log device placement.
gpu_config = tf.ConfigProto(
    log_device_placement=False,
    allow_soft_placement=True,
    gpu_options=gpu_options,
)

local_device_protos = device_lib.list_local_devices()
num_gpus = sum(1 for d in local_device_protos if d.device_type == 'GPU')
print("GPUS:", num_gpus)

# The session must be installed BEFORE the model is loaded. Keras writes the
# loaded weights into the session it holds at load time. Installing a new
# session afterwards (as the code previously did) leaves the model's
# variables uninitialized in that new session, so predict() would fail.
sess = tf.Session(config=gpu_config)
k.backend.set_session(sess)

print("* Loading model ...")
model = k.models.load_model(MODEL_PATH)
if num_gpus > 1:
    # Replicate the model across all visible GPUs for data-parallel inference.
    model = k.utils.multi_gpu_model(model, num_gpus)
# Build the predict function eagerly, so request handlers do not build it lazily.
model._make_predict_function()
print("* Model loaded.")
def prepare_image(image, target):
    """Turn a PIL-style image into a one-element float32 batch.

    Args:
        image: object exposing ``mode``, ``convert()`` and ``resize()``
            (a PIL ``Image``).
        target: size forwarded to ``resize()``; callers pass (width, height).

    Returns:
        numpy float32 array with a new leading batch axis of length 1.
        No pixel scaling or normalization is applied.
    """
    rgb = image if image.mode == "RGB" else image.convert("RGB")
    resized = rgb.resize(target)
    pixels = np.asarray(resized, "float32")
    return pixels[np.newaxis, ...]
def base64_decode_image(a):
    """Decode a base64 payload into raw bytes.

    Args:
        a: base64 text (``str``) or an already-encoded bytes-like object.
            Previously, bytes input raised TypeError on Python 3 because
            ``bytes(a, encoding=...)`` only accepts ``str``.

    Returns:
        The decoded bytes.

    Raises:
        binascii.Error (a ValueError) on malformed padding.
        TypeError if ``a`` is neither text nor bytes-like.
    """
    # Only text needs encoding; bytes/bytearray go straight to b64decode.
    if sys.version_info.major == 3 and isinstance(a, str):
        a = a.encode("utf-8")
    return base64.b64decode(a)
# Flask application; the /predict and /server routes below are registered on it.
app = flask.Flask(__name__)
@app.route("/predict", methods=["POST"])
def predict():
    """Classify an image uploaded as the multipart form field ``image``.

    Returns:
        A JSON string. On success it is ``{"success": true, "results": [...]}``,
        where each result is ``{"pred": argmax index, "score": model output at
        that index}``. On failure it is ``{"success": false, "err_msg": ...}``.
    """
    data = {"success": False}
    # Flask's router already rejects non-POST methods (methods=["POST"]);
    # this check is only a defensive fallback.
    if flask.request.method != "POST":
        data["err_msg"] = "Not support GET method."
        return json.dumps(data)

    upload = flask.request.files.get("image")
    if not upload:
        data["err_msg"] = "Get none."
        return json.dumps(data)

    try:
        image = Image.open(io.BytesIO(upload.read()))
        # prepare_image forces decoding (convert/resize), so truncated or
        # corrupt files also fail inside this try block.
        image = prepare_image(image, (IMG_W, IMG_H))
    except (OSError, ValueError) as err:
        # Report unreadable uploads to the client instead of returning a 500.
        data["err_msg"] = "Invalid image: {}".format(err)
        return json.dumps(data)

    print("* inputs shape:", image.shape)
    outputs = model.predict(image)
    preds = np.argmax(outputs, axis=1)
    results = []
    for pred, output in zip(preds, outputs):
        r = {"pred": int(pred), "score": float(output[pred])}
        print(r)
        results.append(r)
    # Previously this flag was never set here, unlike /server.
    data["success"] = True
    data["results"] = results
    return json.dumps(data)
@app.route("/server", methods=["POST"])
def server():
    """Classify a base64-encoded image sent in a JSON body ``{"image": "..."}``.

    Returns:
        A JSON string. On success it is ``{"success": true, "results": [...]}``,
        where each result is ``{"pred": argmax index, "score": model output at
        that index}``. On failure it is ``{"success": false, "err_msg": ...}``.
    """
    data = {"success": False}
    # Flask's router already rejects non-POST methods (methods=["POST"]);
    # this check is only a defensive fallback.
    if flask.request.method != "POST":
        data["err_msg"] = "Not support GET method."
        return json.dumps(data)

    raw_body = flask.request.get_data()
    if not raw_body:
        data["err_msg"] = "Get none."
        return json.dumps(data)

    try:
        # Each step can fail on client-supplied data:
        #   JSONDecodeError/UnicodeDecodeError (ValueError) - bad body;
        #   KeyError/TypeError - missing "image" key, or a non-object body;
        #   binascii.Error (ValueError) - bad base64;
        #   OSError - bytes that PIL cannot decode.
        payload = json.loads(raw_body)
        image_bytes = base64_decode_image(payload["image"])
        image = Image.open(io.BytesIO(image_bytes))
        image = prepare_image(image, (IMG_W, IMG_H))
    except (KeyError, TypeError, ValueError, OSError) as err:
        data["err_msg"] = "Invalid input: {}".format(err)
        return json.dumps(data)

    print("* inputs shape:", image.shape)
    outputs = model.predict(image)
    preds = np.argmax(outputs, axis=1)
    results = []
    for pred, output in zip(preds, outputs):
        r = {"pred": int(pred), "score": float(output[pred])}
        print(r)
        results.append(r)
    data["success"] = True
    data["results"] = results
    return json.dumps(data)
# Start Flask's built-in server with its default host/port when run as a script.
# NOTE(review): client code follows in this same module. app.run() blocks, so
# the client's main() further down only runs after the server stops. These
# look like two separate scripts pasted together; consider splitting them.
if __name__ == '__main__':
    app.run()
from urllib import request
import base64
import json
# Local image file read by main(), and the base URL it posts to. The URL
# presumably targets the server above running with Flask's default
# host/port; verify this if the server is deployed elsewhere.
IMAGE_PATH, HOST = "test.png", "http://127.0.0.1:5000/"
def base64_encode_image(a):
    """Return the bytes-like object *a* base64-encoded as a ``str``."""
    encoded = base64.b64encode(a)
    return encoded.decode("utf-8")
def main():
    """POST IMAGE_PATH, base64-encoded in JSON, to ``HOST + "server"`` and print the reply.

    Raises:
        OSError if the image file cannot be read.
        urllib.error.URLError if the server cannot be reached.
    """
    # Close the file deterministically; open(...).read() left it for the GC.
    with open(IMAGE_PATH, "rb") as fh:
        raw = fh.read()
    body = json.dumps({"image": base64_encode_image(raw)})
    req = request.Request(HOST + "server", headers={"Content-Type": "application/json"})
    # Close the HTTP response (and its socket) as soon as it has been read.
    with request.urlopen(req, data=bytes(body, encoding="utf-8")) as res:
        result = json.loads(res.read())
    print(result)
# Run the client request when this file is executed as a script.
if __name__ == '__main__':
    main()