#1.配置安裝
-
conda create -n pth_pb python=3.7
conda activate pth_pb
-
pip install tensorflow==2.1.0 -i https://pypi.tuna.tsinghua.edu.cn/simple
-
pip install tensorflow-addons==0.9.1 -i https://pypi.tuna.tsinghua.edu.cn/simple
-
pip install onnx-tf==1.5.0 -i https://pypi.tuna.tsinghua.edu.cn/simple
-
pip install onnx==1.6.0 -i https://pypi.tuna.tsinghua.edu.cn/simple
-
conda install pytorch torchvision #(這裏我用的版本爲pytorch1.4, torchvision0.5.0)
-
注意:如果你的python版本爲3.8,安裝onnx時會出現以下錯誤,降低python版本後解決
問題:python setup.py egg_info Check the logs for full command output.
解決:conda install python=3.7
問題:No module named 'pip._internal.cli.main'
解決:easy_install pip
#2.代碼
pth轉onnx.py
# Export a ResNet-50 checkpoint (.pth) with a 2-output head to an ONNX file.
import torchvision
import torch.onnx
import torch.nn as nn


def resnet50():
    """Return a torchvision ResNet-50 (no pretrained weights) with a 2-output fc layer."""
    model = torchvision.models.resnet50(pretrained=False)
    # Replace the final fully connected layer: 2048 input features -> 2 outputs.
    model.fc = nn.Linear(2048, 2)
    return model


model = resnet50()
# The weights are loaded through a DataParallel wrapper.
# NOTE(review): presumably the checkpoint was saved from a DataParallel model
# (state-dict keys prefixed with "module.") -- confirm against the training code.
model = torch.nn.DataParallel(model)

pthfile = r'my.pth'
loaded_model = torch.load(pthfile, map_location='cpu')
# The checkpoint is a dict; the weights live under its 'state_dict' key.
model.load_state_dict(loaded_model['state_dict'])

# Dummy input used to trace the model during export (batch 1, 3 channels, 200x200).
dummy_input = torch.randn(1, 3, 200, 200)
input_names = ["head_input"]
output_names = ["output"]
onnx_filename = "my.onnx"

# Export the wrapped module itself (model.module), not the DataParallel wrapper.
torch.onnx.export(
    model.module,
    dummy_input,
    onnx_filename,
    verbose=True,
    input_names=input_names,
    output_names=output_names,
)
onnx轉pb.py(注意:以下代碼與上面的pth轉onnx.py相同,只是換成了實際路徑,仍然只導出onnx文件;onnx轉pb的步驟見文末的onnx-tf convert命令)
# Export a ResNet-50 checkpoint (.pth) with a 2-output head to an ONNX file.
# NOTE(review): this script only writes the .onnx file; it does not produce a
# .pb file itself. The ONNX -> pb step appears to be the `onnx-tf convert`
# command shown later in this document, which takes this same .onnx path.
import torchvision
import torch.onnx
import torch.nn as nn


def resnet50():
    """Return a torchvision ResNet-50 (no pretrained weights) with a 2-output fc layer."""
    model = torchvision.models.resnet50(pretrained=False)
    # Replace the final fully connected layer: 2048 input features -> 2 outputs.
    model.fc = nn.Linear(2048, 2)
    return model


model = resnet50()
# The weights are loaded through a DataParallel wrapper.
# NOTE(review): presumably the checkpoint was saved from a DataParallel model
# (state-dict keys prefixed with "module.") -- confirm against the training code.
model = torch.nn.DataParallel(model)

pthfile = r'/Users/haobing1/myMac/model/net_epoch_100.pth'
loaded_model = torch.load(pthfile, map_location='cpu')
# The checkpoint is a dict; the weights live under its 'state_dict' key.
model.load_state_dict(loaded_model['state_dict'])

# Dummy input used to trace the model during export (batch 1, 3 channels, 200x200).
dummy_input = torch.randn(1, 3, 200, 200)
input_names = ["head_input"]
output_names = ["output"]
onnx_filename = "/Users/haobing1/myMac/model/1/resnet50_epoch_100.onnx"

# Export the wrapped module itself (model.module), not the DataParallel wrapper.
torch.onnx.export(
    model.module,
    dummy_input,
    onnx_filename,
    verbose=True,
    input_names=input_names,
    output_names=output_names,
)
pb_predict.py
# Run one image through a TensorFlow GraphDef (.pb) file and print the
# raw predictions together with the argmax index.
import tensorflow as tf
from torchvision import transforms
import numpy as np
from PIL import Image

# Preprocessing: resize to 200x200, convert to a tensor, then normalize each
# of the three channels with the given mean / std values.
transform = transforms.Compose([
    transforms.Resize((200, 200)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

with tf.Graph().as_default():
    # output_graph_def = tf.GraphDef()  # TensorFlow 1.4 version
    output_graph_def = tf.compat.v1.GraphDef()
    output_graph_path = 'my.pb'
    with open(output_graph_path, 'rb') as f:
        output_graph_def.ParseFromString(f.read())
    # name="" imports the nodes without a prefix, so the tensors can be
    # looked up below as "head_input:0" / "output:0".
    _ = tf.import_graph_def(output_graph_def, name="")

    # with tf.Session() as sess:  # TensorFlow 1.4 version
    with tf.compat.v1.Session() as sess:
        image = "demo.jpg"
        # Normalize above is configured for exactly 3 channels, so force RGB;
        # a grayscale or CMYK JPEG would otherwise fail in the transform.
        image_np = Image.open(image).convert("RGB")
        # Add a leading batch dimension, then hand TensorFlow a NumPy array.
        img_input = transform(image_np).unsqueeze(0)
        image_np_expanded = img_input.numpy()

        # sess.run(tf.global_variables_initializer())  # TensorFlow 1.4 version
        sess.run(tf.compat.v1.global_variables_initializer())

        input_tensor = sess.graph.get_tensor_by_name("head_input:0")
        output_tensor = sess.graph.get_tensor_by_name("output:0")
        predictions = sess.run(output_tensor, feed_dict={input_tensor: image_np_expanded})
        index = np.argmax(predictions)
        print("predictions:", predictions)
        print("index:", index)
- 注意:其中onnx轉pb中會出現警告,這裏並沒有評測警告的影響。
第二種方法:從源碼安裝onnx-tensorflow,並使用onnx-tf命令行工具進行轉換:
git clone https://github.com/onnx/onnx-tensorflow.git
cd onnx-tensorflow
pip install -e .
onnx-tf convert -i /Users/haobing1/myMac/model/1/resnet50_epoch_100.onnx -o /Users/haobing1/myMac/model/1/res.pb
無警告onnx轉換pb