pytorch模型轉tensorflow pb模型

# 1. 配置安裝

  • conda create -n pth_pb python=3.7

  • pip install tensorflow==2.1.0 -i https://pypi.tuna.tsinghua.edu.cn/simple

  • pip install tensorflow-addons==0.9.1 -i https://pypi.tuna.tsinghua.edu.cn/simple

  • pip install onnx-tf==1.5.0 -i https://pypi.tuna.tsinghua.edu.cn/simple

  • pip install onnx==1.6.0 -i https://pypi.tuna.tsinghua.edu.cn/simple

  • conda install pytorch==1.4.0 torchvision==0.5.0 -c pytorch  # 這裏我用的版本為 pytorch 1.4、torchvision 0.5.0

  • 注意:如果你的 python 版本為 3.8,安裝 onnx 時會出現以下錯誤,將 python 降到 3.7 後即可解決:

    問題:python setup.py egg_info Check the logs for full command output.

    解決:conda install python=3.7

    問題:No module named 'pip._internal.cli.main'

    解決:easy_install pip

# 2. 代碼
pth轉onnx.py

import torchvision
import torch.onnx
import torch.nn as nn


def resnet50():
    """Build a ResNet-50 whose final layer outputs 2 classes.

    Weights start random (``pretrained=False``); they are replaced by the
    checkpoint loaded below.
    """
    model = torchvision.models.resnet50(pretrained=False)
    # Swap the stock classification head for a 2-output linear layer;
    # 2048 is the input width of torchvision's ResNet-50 ``fc`` layer.
    model.fc = nn.Linear(2048, 2)
    return model


model = resnet50()
# Wrap in DataParallel so parameter names carry the "module." prefix; this
# presumes the checkpoint was saved from a DataParallel model (strict
# load_state_dict would otherwise fail) -- TODO confirm for your checkpoint.
model = torch.nn.DataParallel(model)
pthfile = r'my.pth'
loaded_model = torch.load(pthfile, map_location='cpu')
model.load_state_dict(loaded_model['state_dict'])
# Export for inference: BatchNorm uses running statistics, dropout disabled.
model.eval()

# Dummy input that fixes the exported graph's input shape (N, C, H, W);
# 200x200 must match the Resize used at prediction time.
dummy_input = torch.randn(1, 3, 200, 200)
input_names = ["head_input"]
output_names = ["output"]
onnx_filename = "my.onnx"
# Export the underlying network rather than the DataParallel wrapper.
torch.onnx.export(model.module, dummy_input, onnx_filename, verbose=True,
                  input_names=input_names, output_names=output_names)

onnx轉pb.py

# Convert the ONNX model produced by the pth-to-onnx script into a
# TensorFlow .pb file using onnx-tf (installed in the setup section).
import onnx
from onnx_tf.backend import prepare

onnx_filename = "my.onnx"  # output of the pth-to-onnx script
pb_filename = "my.pb"      # input of pb_predict.py

onnx_model = onnx.load(onnx_filename)
# Validate the graph first so a malformed export fails here with a clear
# message rather than deep inside the converter.
onnx.checker.check_model(onnx_model)

# Build a TensorFlow representation of the ONNX graph and write it out.
# NOTE(review): with onnx-tf 1.5.x export_graph writes a serialized
# GraphDef (what pb_predict.py expects); newer onnx-tf releases write a
# SavedModel directory instead -- confirm against your installed version.
tf_rep = prepare(onnx_model)
tf_rep.export_graph(pb_filename)

pb_predict.py

import tensorflow as tf
from torchvision import transforms
import numpy as np
from PIL import Image

# Preprocessing presumably mirrors the training pipeline: 200x200 input
# (same size as the dummy input used for the ONNX export) and ImageNet
# mean/std normalization.
transform = transforms.Compose([
    transforms.Resize((200, 200)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

output_graph_path = 'my.pb'
image_path = "demo.jpg"

with tf.Graph().as_default():
    # tf.compat.v1 keeps this working on TF 2.x; on TF 1.x, tf.GraphDef and
    # tf.Session can be used directly instead.
    output_graph_def = tf.compat.v1.GraphDef()
    with open(output_graph_path, 'rb') as f:
        output_graph_def.ParseFromString(f.read())
    tf.import_graph_def(output_graph_def, name="")

    with tf.compat.v1.Session() as sess:
        # Force 3 channels: grayscale or RGBA images would otherwise break
        # the 3-channel Normalize above. The with-block closes the file.
        with Image.open(image_path) as img:
            img_input = transform(img.convert("RGB")).unsqueeze(0)
        image_np_expanded = img_input.numpy()
        # No global_variables_initializer() needed: import_graph_def does not
        # register variables in the global-variables collection, so it would
        # be a no-op here.
        input_tensor = sess.graph.get_tensor_by_name("head_input:0")
        output_tensor = sess.graph.get_tensor_by_name("output:0")
        predictions = sess.run(output_tensor,
                               feed_dict={input_tensor: image_np_expanded})
        # Batch size is 1, so the flat argmax is the predicted class index.
        index = np.argmax(predictions)
        print("predictions:", predictions)
        print("index:", index)

  • 注意:執行 onnx轉pb.py 時會出現警告,這裏並未評估這些警告的影響。
    若想避免警告,可使用第二種方法(onnx-tf 命令行工具):
    git clone https://github.com/onnx/onnx-tensorflow.git
    cd onnx-tensorflow
    pip install -e .
    onnx-tf convert -i /Users/haobing1/myMac/model/1/resnet50_epoch_100.onnx -o /Users/haobing1/myMac/model/1/res.pb
    即可無警告地將 onnx 轉換為 pb。
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章