import ast
import numpy as np
from PIL import Image
from collections import OrderedDict
from model_service.pytorch_model_service import PTServingBaseService
import torch
from torchvision import transforms
input_size = 380
test_transforms = transforms.Compose([
transforms.Resize(input_size),
transforms.CenterCrop(input_size),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
class garbage_classify_service(PTServingBaseService):
def __init__(self, model_name, model_path):
print('model_name: ', model_name)
print('model_path: ', model_path)
# these three parameters are no need to modify
self.model_name = model_name
# self.model_path = model_path
# self.signature_key = 'predict_images'
# add the input and output key of your pb model here,
# these keys are defined when you save a pb file
self.input_key_1 = 'input_img'
map_location = torch.device('cpu')
self.model = torch.load(model_path, map_location=map_location)
self.model.eval()
self.idx_to_cls = {'0': 0, '1': 1, '2': 10, '3': 11, '4': 12, '5': 13, '6': 14, '7': 15, '8': 16, '9': 17,
'10': 18, '11': 19, '12': 2, '13': 20, '14': 21, '15': 22, '16': 23, '17': 24, '18': 25,
'19': 26, '20': 27, '21': 28, '22': 29, '23': 3, '24': 30, '25': 31, '26': 32, '27': 33,
'28': 34, '29': 35, '30': 36, '31': 37, '32': 38, '33': 39, '34': 4, '35': 40, '36': 41,
'37': 42, '38': 43, '39': 44, '40': 45, '41': 46, '42': 47, '43': 48, '44': 49, '45': 5,
'46': 50, '47': 51, '48': 52, '49': 53, '50': 6, '51': 7, '52': 8, '53': 9}
self.label_id_name_dict = \
{
"0": "工藝品/仿唐三彩",
"1": "工藝品/仿宋木葉盞",
"2": "工藝品/布貼繡",
"3": "工藝品/景泰藍",
"4": "工藝品/木馬勺臉譜",
"5": "工藝品/柳編",
"6": "工藝品/葡萄花鳥紋銀香囊",
"7": "工藝品/西安剪紙",
"8": "工藝品/陝歷博唐妞系列",
"9": "景點/關中書院",
"10": "景點/兵馬俑",
"11": "景點/南五臺",
"12": "景點/大興善寺",
"13": "景點/大觀樓",
"14": "景點/大雁塔",
"15": "景點/小雁塔",
"16": "景點/未央宮城牆遺址",
"17": "景點/水陸庵壁塑",
"18": "景點/漢長安城遺址",
"19": "景點/西安城牆",
"20": "景點/鐘樓",
"21": "景點/長安華嚴寺",
"22": "景點/阿房宮遺址",
"23": "民俗/嗩吶",
"24": "民俗/皮影",
"25": "特產/臨潼火晶柿子",
"26": "特產/山茱萸",
"27": "特產/玉器",
"28": "特產/閻良甜瓜",
"29": "特產/陝北紅小豆",
"30": "特產/高陵冬棗",
"31": "美食/八寶玫瑰鏡糕",
"32": "美食/涼皮",
"33": "美食/涼魚",
"34": "美食/德懋恭水晶餅",
"35": "美食/攪團",
"36": "美食/枸杞燉銀耳",
"37": "美食/柿子餅",
"38": "美食/漿水面",
"39": "美食/灌湯包",
"40": "美食/燒肘子",
"41": "美食/石子餅",
"42": "美食/神仙粉",
"43": "美食/粉湯羊血",
"44": "美食/羊肉泡饃",
"45": "美食/肉夾饃",
"46": "美食/蕎麪餄餎",
"47": "美食/菠菜面",
"48": "美食/蜂蜜涼糉子",
"49": "美食/蜜餞張口酥餃",
"50": "美食/西安油茶",
"51": "美食/貴妃雞翅",
"52": "美食/醪糟",
"53": "美食/金線油塔"
}
def _preprocess(self, data):
preprocessed_data = {}
for k, v in data.items():
for file_name, file_content in v.items():
img = Image.open(file_content)
img = test_transforms(img)
preprocessed_data[k] = img
return preprocessed_data
def _inference(self, data):
"""
model inference function
Here are a inference example of resnet, if you use another model, please modify this function
"""
img = data[self.input_key_1]
img = img[np.newaxis, :, :, :] # the input tensor shape of resnet is [?, 224, 224, 3]
# pred_score = self.sess.run([self.output_score], feed_dict={self.input_images: img})
pred_score = self.model(img)
pred_score = pred_score.detach().numpy()
if pred_score is not None:
pred_label = np.argmax(pred_score, axis=1)[0]
pred_label = self.idx_to_cls[str(int(pred_label))]
result = {'result': self.label_id_name_dict[str(pred_label)]}
else:
result = {'result': 'predict score is None'}
return result
def _postprocess(self, data):
return data
“華爲雲杯”2019人工智能創新應用大賽:版本一customize_service.py文件
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.