26圖像分類

一、圖像分類

1.1 模型是如何將圖像分類的?

在這裏插入圖片描述
對於蜜蜂螞蟻二分類模型:
從人的角度來看,是從輸入一張RGB圖像到輸出一種動物的過程
從計算機角度看,是從輸入3-d張量到輸出字符串的過程

在這裏插入圖片描述
類別名是通過標籤進行轉換得到的,在這裏也就是0和1,而輸出的0,1則是通過模型輸出的向量取最大值而得到的,而模型輸出向量則是通過構造複雜的模型而得到的

實際的運行順序:
輸入3d張量到模型中,模型經過複雜的數學運算,輸出一個向量,這個向量就是模型的輸出,然後再對輸出的向量取最大值和標籤與類別名的轉換,最後纔得到最終的字符串的輸出

在這裏插入圖片描述

1.2 圖像分類的Inference(推理)

圖像分類的Inference(推理)步驟:

  1. 獲取數據與標籤
  2. 選擇模型,損失函數,優化器
  3. 寫訓練代碼
  4. 寫inference代碼

Inference代碼基本步驟:

  1. 獲取數據與模型
  2. 數據變換,如RGB → 4D-Tensor
  3. 前向傳播
  4. 輸出保存預測結果

Inference階段注意事項:

  1. 確保 model處於eval狀態而非training
  2. 設置torch.no_grad(),減少內存消耗
  3. 數據預處理需保持一致, RGB or BGR?

二、resnet18模型inference代碼

# -*- coding: utf-8 -*-

import os
import time
import torch.nn as nn
import torch
import torchvision.transforms as transforms
from PIL import Image
from matplotlib import pyplot as plt
import torchvision.models as models
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")

# config
vis = True
# vis = False
vis_row = 4

norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

inference_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

classes = ["ants", "bees"]


def img_transform(img_rgb, transform=None):
    """
    將數據轉換爲模型讀取的形式
    :param img_rgb: PIL Image
    :param transform: torchvision.transform
    :return: tensor
    """

    if transform is None:
        raise ValueError("找不到transform!必須有transform對img進行處理")

    img_t = transform(img_rgb)             # 將rgb圖像轉化爲tensor
    return img_t


def get_img_name(img_dir, format="jpg"):
    """
    獲取文件夾下format格式的文件名
    :param img_dir: str
    :param format: str
    :return: list
    """
    file_names = os.listdir(img_dir)
    img_names = list(filter(lambda x: x.endswith(format), file_names))

    if len(img_names) < 1:
        raise ValueError("{}下找不到{}格式數據".format(img_dir, format))
    return img_names


def get_model(m_path, vis_model=False):

    resnet18 = models.resnet18()
    num_ftrs = resnet18.fc.in_features
    resnet18.fc = nn.Linear(num_ftrs, 2)

    checkpoint = torch.load(m_path)
    resnet18.load_state_dict(checkpoint['model_state_dict'])

    if vis_model:
        from torchsummary import summary
        summary(resnet18, input_size=(3, 224, 224), device="cpu")

    return resnet18


if __name__ == "__main__":

    img_dir = os.path.join("..", "..", "data/hymenoptera_data/val/bees")
    model_path = "./checkpoint_14_epoch.pkl"
    time_total = 0
    img_list, img_pred = list(), list()

    # 1. data
    img_names = get_img_name(img_dir)
    num_img = len(img_names)

    # 2. model
    resnet18 = get_model(model_path, True)
    resnet18.to(device)                       # 將模型遷移到指定設備上
    resnet18.eval()                           # 通過eval(),指明模型不是在訓練狀態

    with torch.no_grad():                     # torch.no_grad()告訴pytorch,下面所有計算不計算梯度
        for idx, img_name in enumerate(img_names):

            path_img = os.path.join(img_dir, img_name)

            # step 1/4 : path --> img
            img_rgb = Image.open(path_img).convert('RGB')

            # step 2/4 : img --> tensor
            img_tensor = img_transform(img_rgb, inference_transform)
            img_tensor.unsqueeze_(0)
            img_tensor = img_tensor.to(device)

            # step 3/4 : tensor --> vector
            time_tic = time.time()
            outputs = resnet18(img_tensor)
            time_toc = time.time()

            # step 4/4 : visualization
            _, pred_int = torch.max(outputs.data, 1)
            pred_str = classes[int(pred_int)]

            if vis:
                img_list.append(img_rgb)
                img_pred.append(pred_str)

                if (idx+1) % (vis_row*vis_row) == 0 or num_img == idx+1:
                    for i in range(len(img_list)):
                        plt.subplot(vis_row, vis_row, i+1).imshow(img_list[i])
                        plt.title("predict:{}".format(img_pred[i]))
                    plt.show()
                    plt.close()
                    img_list, img_pred = list(), list()

            time_s = time_toc-time_tic
            time_total += time_s

            print('{:d}/{:d}: {} {:.3f}s '.format(idx + 1, num_img, img_name, time_s))

    print("\ndevice:{} total time:{:.1f}s mean:{:.3f}s".
          format(device, time_total, time_total/num_img))
    if torch.cuda.is_available():
        print("GPU name:{}".format(torch.cuda.get_device_name()))


運行結果:
在這裏插入圖片描述
預測結果:
在這裏插入圖片描述

三、resnet18結構分析

在這裏插入圖片描述
經典的卷積神經網絡:alexnet, vgg, googlenet, resnet, densenet
輕量化卷積神經網絡:mobilenet, shufflenet, squeezenet
自動搜索結構網絡:mnasenet
在這裏插入圖片描述
在這裏插入圖片描述

發佈了111 篇原創文章 · 獲贊 9 · 訪問量 9134
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章