街景字符編碼識別之數據讀取與擴增

點贊再看,養成習慣!

前言

繼上一節對數據進行極其簡單的數據分析後,這一節開始做數據加載,目標就是組織好數據,可以以一種正確的姿勢餵給後續的模型。不同的深度學習框架,數據加載這一塊是有所不同的,這裏講解的是PyTorch的數據處理工具。

正文

圖像讀取

這裏主要介紹兩個常用的庫:
Pillow【輕量級】
Pillow是Python圖像處理函式庫(PIL)的一個分支。Pillow提供了常見的圖像讀取和處理的操作,而且可以與ipython notebook無縫集成,是應用比較廣泛的庫。

from PIL import Image

# 圖像讀取
im =Image.open(path)

OpenCV【重量級】
OpenCV是一個跨平臺的計算機視覺庫,最早由Intel開源得來。OpenCV發展的非常早,擁有衆多的計算機視覺、數字圖像處理和機器視覺等功能。OpenCV在功能上比Pillow更加強大很多,學習成本也高很多。

import cv2

# 圖像讀取
img = cv2.imread('cat.jpg')
# Opencv默認顏色通道順序是BGR,轉換一下
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

【小編友情提醒】
雖然python程序在使用opencv是導入cv2,但是真正用conda或者pip下載的庫的名字叫opencv-python,這點要格外注意!

數據擴增

在深度學習中數據擴增方法非常重要,數據擴增可以增加訓練集的樣本,同時也可以有效緩解模型過擬合的情況,也可以給模型帶來的更強的泛化能力。這裏是針對圖像數據進行擴增,所以常見的角度有圖像顏色、尺寸、形態、空間和像素等。其實小編以前常見常用的也只有圖像顏色變化、翻轉、裁剪這三種操作。不過這裏字符不可以進行翻轉,例如6倒過來會變成9,改變字符原先的含義。
常見的庫

  1. torchvision
    pytorch官方提供的數據擴增庫,提供了基本的數據數據擴增方法,可以無縫與torch進行集成;但數據擴增方法種類較少,且速度中等;
    常用方法:
    transforms.RandomCrop 隨機區域裁剪
    transforms.ColorJitter 對圖像顏色的對比度、飽和度和零度進行變換
    transforms.Grayscale 對圖像進行灰度變換
    transforms.Pad 使用固定值進行像素填充
    transforms.RandomRotation 隨機旋轉
SVHNDataset(train_path, train_label,
                    transforms.Compose([
                        transforms.Resize((64, 128)), 
                        transforms.ColorJitter(0.3, 0.3, 0.2), #顏色變化
                        transforms.RandomRotation(5), #隨機旋轉,不能旋轉太多
                        transforms.ToTensor(),
                        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 
                    ]))
  1. imgaug
    imgaug是常用的第三方數據擴增庫,提供了多樣的數據擴增方法,且組合起來非常方便,速度較快;
  2. albumentations
    是常用的第三方數據擴增庫,提供了多樣的數據擴增方法,對圖像分類、語義分割、物體檢測和關鍵點檢測都支持,速度較快。

圖像擴增示例效果圖:
圖像擴增

PyTorch數據加載

PyTorch數據加載的過程是:數據集本身要轉化成Dataset實例,而提供給模型訓練、驗證或測試時的讀取要用DataLoader實例。

  • Dataset:對數據集的封裝,提供索引方式的對數據樣本進行讀取
  • DataLoader:對Dataset進行封裝,提供批量讀取的迭代讀取,可以用多進程加速

實施流程:

  1. 繼承Dataset類,並實現__init__、getitem、__len__等函數成員,這裏類名爲SVHNDataset。
class SVHNDataset(Dataset):
    def __init__(self, img_path, img_label, transform=None):
        self.img_path = img_path #所有圖像數據路徑
        self.img_label = img_label #所有圖像標籤數據
        if transform is not None:
            self.transform = transform #預處理流
        else:
            self.transform = None

    def __getitem__(self, index):
        # just handle one data
        img = Image.open(self.img_path[index]).convert('RGB') #讀取圖像

        if self.transform is not None:
            img = self.transform(img) #預處理

        # 定長字符識別策略,填充的字符爲10,這樣不會與有效字符0-9發生碰撞
        lbl = np.array(self.img_label[index], dtype=np.int)
        lbl = list(lbl) + (5 - len(lbl)) * [10]

        return img, torch.from_numpy(np.array(lbl[:5]))

    def __len__(self):
        return len(self.img_path) #數據集大小
  1. DataLoader加載SVHNDataset
train_loader = torch.utils.data.DataLoader(
        SVHNDataset(train_path, train_label,
                    transforms.Compose([
                        transforms.Resize((64, 128)),
                        transforms.ColorJitter(0.3, 0.3, 0.2),
                        transforms.RandomRotation(5),
                        transforms.ToTensor(),
                        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                    ])),
        batch_size=10,  # 每批樣本個數
        shuffle=False,  # 是否打亂順序
        num_workers=5,  #進程個數
    )

結語

PyTorch數據加載的流程較爲固定,但因爲Dataset能夠自定義,所以數據讀取就比較靈活。值得說一句的是,數據預處理的數據擴增並不是說直接擴增數據,比如把3W的訓練集擴增到更多,而是在深度學習的訓練過程中把每張圖片都通過transform處理流進行變化,這樣不同的迭代中同一索引的圖像都不一定相同,從而達到了數據擴增的目標。

參考文獻

  1. Pillow的官方文檔:https://pillow.readthedocs.io/en/stable/
  2. OpenCV官網:https://opencv.org/
    OpenCV Github:https://github.com/opencv/opencv
    OpenCV 擴展算法庫:https://github.com/opencv/opencv_contrib
  3. torchvision: https://github.com/pytorch/vision
  4. imgaug: https://github.com/aleju/imgaug
  5. albumentations: https://albumentations.readthedocs.io

童鞋們,讓小編聽見你們的聲音,點贊評論,一起加油。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章