將ResNet分類器做成一個小網站界面

在上一節寫到的 Pytorch完整訓練自己的數據集 ,我們可以將訓練好的模型和演示代碼寫入網站中,方便演示。還是以pokeman數據分類爲例子。

因爲整個用到的是python代碼,所以構建網頁我們也是採用python語言,用的框架是:Flask

1、首先寫分類demo代碼

這一步先要把訓練好的加載,直接去預測未見過的圖片。代碼命名爲demp.py如下:

import torch
from torch import optim, nn
from torchvision import transforms
from torchvision.models import resnet18
from utils import Flatten, softmax
from PIL import Image
import os
import numpy as np

def predicts(img):
    device = torch.device('cuda')
    torch.manual_seed(1234)
    resize = 224
    className = {
        '0': 'bulbasaur',
        '1': 'charmander',
        '2': 'mewtw',
        '3': 'pikachu',
        '4': 'squirtle'}

    # model = ResNet18(5).to(device)
    trained_model = resnet18(pretrained=True)
    model = nn.Sequential(*list(trained_model.children())[:-1],  # [b, 512, 1, 1]
                          Flatten(),  # [b, 512, 1, 1] => [b, 512]
                          nn.Linear(512, 5)
                          ).to(device)
    # x = torch.randn(2, 3, 224, 224)
    # print(model(x).shape)
    basepath = os.path.dirname(__file__)  # 當前文件所在路徑
    ckpt_path = os.path.join(basepath, 'best.mdl')
    print(ckpt_path)
    model.load_state_dict(torch.load(ckpt_path))
    print('loaded from ckpt!')
    tf = transforms.Compose([
        lambda x: Image.open(x).convert('RGB'),  # string path= > image data
        transforms.Resize(
            (int(resize), int(resize))),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    # img = 'pokeman\\pikachu\\00000003.jpg'
    x = tf(img)
    x = x.unsqueeze(0)
    x = x.to(device)
    
    model.eval()
    with torch.no_grad(): 
        logits = model(x)
        pred = logits.argmax(dim=1).item()
        prob = np.max(softmax(logits.cpu().numpy()), axis=1)[0]
    # print('Our model predicts : %s'%className[str(pred)])
    return className[str(pred)], str(round(prob * 100, 2)) + '%'
    
if __name__ == '__main__':
    img = r'C:\spyder\imgshow\static\images\00000009.png'
    pres = predicts(img)

需要注意的是由於不需要訓練,只是測試,需要添加:model.eval(),同時我們不需要求導求梯度,因此在模型運算的前面加上with torch.no_grad():

上面函數返回了預測的類別,以及置信度,這個需要顯示在網頁上面

2、構建一個簡單的網站

這一步我們採用Flask寫一個很簡單的網站,代碼命名爲resnet_class.py:

# coding:utf-8
from flask import Flask, render_template, request, redirect, url_for, make_response, jsonify
from werkzeug.utils import secure_filename
import os
import cv2
import time
from demo import predicts
from datetime import timedelta

# 設置允許的文件格式
ALLOWED_EXTENSIONS = set(['png', 'jpg', 'JPG', 'PNG', 'bmp'])

def allowed_file(filename):
    return '.' in filename and filename.rsplit('.', 1)[1] in ALLOWED_EXTENSIONS

app = Flask(__name__)
# 設置靜態文件緩存過期時間
app.send_file_max_age_default = timedelta(seconds=1)

# @app.route('/resnet', methods=['POST', 'GET'])
@app.route('/', methods=['POST', 'GET'])  # 添加路由
def upload():
    if request.method == 'POST':
        f = request.files['file']

        if not (f and allowed_file(f.filename)):
            return jsonify({"error": 1001,
                            "msg": "請檢查上傳的圖片類型,僅限於png、PNG、jpg、JPG、bmp"})

        user_input = request.form.get("name")
        basepath = os.path.dirname(__file__)  # 當前文件所在路徑
        upload_path = os.path.join(
            basepath,
            'static/images',
            secure_filename(
                f.filename))  # 注意:沒有的文件夾一定要先創建,不然會提示沒有該路徑
        # upload_path = os.path.join(basepath, 'static/images','test.jpg')
        # #注意:沒有的文件夾一定要先創建,不然會提示沒有該路徑
        f.save(upload_path)

        # 使用Opencv轉換一下圖片格式和名稱
        img = cv2.imread(upload_path)
        cv2.imwrite(os.path.join(basepath, 'static/images', 'test.jpg'), img)
        pres, pro = predicts(upload_path)
        print(upload_path)

        return render_template(
					            'upload_ok.html',
					            userinput=user_input,
					            classresult=pres,
					            classpro=pro,
					            val1=time.time())
    return render_template('upload.html')

if __name__ == '__main__':
    # app.debug = True
    app.run(host='127.0.0.1', port=5000, debug=True)

說明:
1、在運行該代碼的時候,需要在終端運行:

set FLASK_APP=resnet_class.py
flask run

便可以運行該代碼。

2、下面這句話是個裝飾器,可以看之前的寫什麼是裝飾器:https://blog.csdn.net/lifei1229/article/details/105757933

@app.route('/', methods=['POST', 'GET'])  # 添加路由

route(’/’) 中傳入’/'表示根目錄,即在輸入網站不需要加上後面目錄:

如果我們寫成了

@app.route(’/resnet’, methods=[‘POST’, ‘GET’])

那麼網站後面需要添加resnet:

3、寫入網頁

這一步就是構建上傳圖片,顯示圖片,顯示分類結果的網頁了。
下面是原始網頁upload.html

<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <title>使用ResNet分類圖像演示平臺</title>
</head>
<body>
    <h1>使用ResNet分類圖像演示平臺</h1>
    <form action="" enctype='multipart/form-data' method='POST'>
        <input type="file" name="file" style="margin-top:20px;"/>
        <br>
        <h2>請輸入你認爲這張圖片的分類標籤:</h2>
        <input type="text" class="txt_input" name="name"  value="請輸入pokeman類別" style="margin-top:10px;"/>
        <input type="submit" value="上傳" class="button-new" style="margin-top:15px;"/>
    </form>
</body>
</html>

如果圖片上傳成功,則會用到下面這個網頁,命名爲 upload_ok.html

<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <title>使用ResNet分類圖像演示平臺</title>
</head>
<body>
    <h1>使用ResNet分類圖像演示平臺</h1>
    <form action="" enctype='multipart/form-data' method='POST'>
        <input type="file" name="file" style="margin-top:20px;"/>
        <br>
        <h2>請輸入你認爲這張圖片的分類標籤:</h2>
        <input type="text" class="txt_input" name="name"  value="請輸入pokeman類別" style="margin-top:10px;"/>
        <input type="submit" value="上傳" class="button-new" style="margin-top:15px;"/>
    </form>
    <h2>閣下認爲這張照片是:{{userinput}}</h2>
    <img src="{{ url_for('static', filename= './images/test.jpg',_t=val1) }}" width="400" height="400" alt="你的圖片被外星人劫持了~~"/>
    <h2>我們使用ResNet模型預測,有{{classpro}}概率認爲它是 {{classresult}}</h2>
</body>
</html>

4、整個文件的層級結構

由於網頁和python代碼交互需要用到

from flask import Flask, render_template,

且文件的存放位置也有要求,在本例子中:
在這裏插入圖片描述
templates是存放網頁代碼的文件夾

5、演示結果

再看一個例子:

6、總結

將網站代碼放在服務器上,且外網能訪問,便可以向別人演示你的深度學習模型的效果了,不只是分類,圖像去噪,增強,檢測,分割都可以弄一個簡單的網站,向別人展示你的優秀的模型效果。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章