Datawhale零基礎入門CV賽事(街景字符編碼識別)-Task2 數據讀取與數據擴增

Datawhale 零基礎入門CV賽事-Task2 數據讀取與數據擴增

在上一章節,我們給大家講解了賽題的內容和三種不同的解決方案。從本章開始我們將逐漸的學習使用【定長字符識別】思路來構建模型,逐步講解賽題的解決方案和相應知識點。

2 數據讀取與數據擴增

本章主要內容爲數據讀取、數據擴增方法和Pytorch讀取賽題數據三個部分組成。

2.1 學習目標

  • 學習Python和Pytorch中圖像讀取
  • 學會擴增方法和Pytorch讀取賽題數據

2.2 圖像讀取

由於賽題數據是圖像數據,賽題的任務是識別圖像中的字符。因此我們首先需要完成對數據的讀取操作,在Python中有很多庫可以完成數據讀取的操作,比較常見的有Pillow和OpenCV。

2.2.1 Pillow

Pillow是Python圖像處理函式庫(PIL)的一個分支。Pillow提供了常見的圖像讀取和處理的操作,而且可以與ipython notebook無縫集成,是應用比較廣泛的庫。

效果 代碼
[外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-86IJE0AV-1590210258334)(IMG/Task02/Pillow讀取原圖.png)] from PIL import Image
# 導入Pillow庫

# 讀取圖片
im =Image.open(cat.jpg’)
在這裏插入圖片描述 from PIL import Image, ImageFilter
im = Image.open(‘cat.jpg’)
# 應用模糊濾鏡:
im2 = im.filter(ImageFilter.BLUR)
im2.save(‘blur.jpg’, ‘jpeg’)
上傳(img-iCKcDoEU-1590210258364)(IMG/Task02/Pillow縮放原圖.png)] from PIL import Image
# 打開一個jpg圖像文件,注意是當前路徑:
im = Image.open(‘cat.jpg’)
im.thumbnail((w//2, h//2))
im.save(‘thumbnail.jpg’, ‘jpeg’)

當然上面只演示了Pillow最基礎的操作,Pillow還有很多圖像操作,是圖像處理的必備庫。
Pillow的官方文檔:https://pillow.readthedocs.io/en/stable/

2.2.2 OpenCV

OpenCV是一個跨平臺的計算機視覺庫,最早由Intel開源得來。OpenCV發展的非常早,擁有衆多的計算機視覺、數字圖像處理和機器視覺等功能。OpenCV在功能上比Pillow更加強大很多,學習成本也高很多。

效果 代碼
在這裏插入圖片描述 import cv2
# 導入Opencv庫
img = cv2.imread(‘cat.jpg’)
# Opencv默認顏色通道順序是BRG,轉換一下
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
[外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-NngbmnQk-1590210258379)(IMG/Task02/opencv灰度圖.png)] import cv2
# 導入Opencv庫
img = cv2.imread(‘cat.jpg’)
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
# 轉換爲灰度圖
[外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-X92KZkdw-1590210258381)(IMG/Task02/opencv邊緣檢測.png)] import cv2
# 導入Opencv庫
img = cv2.imread(‘cat.jpg’)
img =cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
# 轉換爲灰度圖
# Canny邊緣檢測
edges = cv2.Canny(img, 30, 70)
cv2.imwrite(‘canny.jpg’, edges)

OpenCV包含了衆多的圖像處理的功能,OpenCV包含了你能想得到的只要與圖像相關的操作。此外OpenCV還內置了很多的圖像特徵處理算法,如關鍵點檢測、邊緣檢測和直線檢測等。
OpenCV官網:https://opencv.org/
OpenCV Github:https://github.com/opencv/opencv
OpenCV 擴展算法庫:https://github.com/opencv/opencv_contrib

2.3 數據擴增方法

在上一小節中給大家初步介紹了Pillow和OpenCV的使用,現在回到賽題街道字符識別任務中。在賽題中我們需要對的圖像進行字符識別,因此需要我們完成的數據的讀取操作,同時也需要完成數據擴增(Data Augmentation)操作。

2.3.1 數據擴增介紹

在深度學習中數據擴增方法非常重要,數據擴增可以增加訓練集的樣本,同時也可以有效緩解模型過擬合的情況,也可以給模型帶來的更強的泛化能力。

[外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-1czwGJer-1590210258384)(IMG/Task02/數據擴增.png)]

  • 數據擴增爲什麼有用?

在深度學習模型的訓練過程中,數據擴增是必不可少的環節。現有深度學習的參數非常多,一般的模型可訓練的參數量基本上都是萬到百萬級別,而訓練集樣本的數量很難有這麼多。
其次數據擴增可以擴展樣本空間,假設現在的分類模型需要對汽車進行分類,左邊的是汽車A,右邊爲汽車B。如果不使用任何數據擴增方法,深度學習模型會從汽車車頭的角度來進行判別,而不是汽車具體的區別。

[外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-8szjBxJs-1590210258386)(IMG/Task02/數據擴增car.png)]

  • 有哪些數據擴增方法?

數據擴增方法有很多:從顏色空間、尺度空間到樣本空間,同時根據不同任務數據擴增都有相應的區別。
對於圖像分類,數據擴增一般不會改變標籤;對於物體檢測,數據擴增會改變物體座標位置;對於圖像分割,數據擴增會改變像素標籤。

2.3.2 常見的數據擴增方法

在常見的數據擴增方法中,一般會從圖像顏色、尺寸、形態、空間和像素等角度進行變換。當然不同的數據擴增方法可以自由進行組合,得到更加豐富的數據擴增方法。

以torchvision爲例,常見的數據擴增方法包括:

  • transforms.CenterCrop 對圖片中心進行裁剪
  • transforms.ColorJitter 對圖像顏色的對比度、飽和度和零度進行變換
  • transforms.FiveCrop 對圖像四個角和中心進行裁剪得到五分圖像
  • transforms.Grayscale 對圖像進行灰度變換
  • transforms.Pad 使用固定值進行像素填充
  • transforms.RandomAffine 隨機仿射變換
  • transforms.RandomCrop 隨機區域裁剪
  • transforms.RandomHorizontalFlip 隨機水平翻轉
  • transforms.RandomRotation 隨機旋轉
  • transforms.RandomVerticalFlip 隨機垂直翻轉

貓貓

在本次賽題中,賽題任務是需要對圖像中的字符進行識別,因此對於字符圖片並不能進行翻轉操作。比如字符6經過水平翻轉就變成了字符9,會改變字符原本的含義。

2.3.3 常用的數據擴增庫

  • torchvision

https://github.com/pytorch/vision
pytorch官方提供的數據擴增庫,提供了基本的數據數據擴增方法,可以無縫與torch進行集成;但數據擴增方法種類較少,且速度中等;

  • imgaug

https://github.com/aleju/imgaug
imgaug是常用的第三方數據擴增庫,提供了多樣的數據擴增方法,且組合起來非常方便,速度較快;

  • albumentations

https://albumentations.readthedocs.io
是常用的第三方數據擴增庫,提供了多樣的數據擴增方法,對圖像分類、語義分割、物體檢測和關鍵點檢測都支持,速度較快。

2.4 Pytorch讀取數據

由於本次賽題我們使用Pytorch框架講解具體的解決方案,接下來將是解決賽題的第一步使用Pytorch讀取賽題數據。
在Pytorch中數據是通過Dataset進行封裝,並通過DataLoder進行並行讀取。所以我們只需要重載一下數據讀取的邏輯就可以完成數據的讀取。

import os, sys, glob, shutil, json
import cv2

from PIL import Image
import numpy as np

import torch
from torch.utils.data.dataset import Dataset
import torchvision.transforms as transforms

class SVHNDataset(Dataset):
    def __init__(self, img_path, img_label, transform=None):
        self.img_path = img_path
        self.img_label = img_label 
        if transform is not None:
            self.transform = transform
        else:
            self.transform = None

    def __getitem__(self, index):
        img = Image.open(self.img_path[index]).convert('RGB')

        if self.transform is not None:
            img = self.transform(img)
        
        # 原始SVHN中類別10爲數字0
        lbl = np.array(self.img_label[index], dtype=np.int)
        lbl = list(lbl)  + (5 - len(lbl)) * [10]
        
        return img, torch.from_numpy(np.array(lbl[:5]))

    def __len__(self):
        return len(self.img_path)

train_path = glob.glob('../input/train/*.png')
train_path.sort()
train_json = json.load(open('../input/train.json'))
train_label = [train_json[x]['label'] for x in train_json]

data = SVHNDataset(train_path, train_label,
          transforms.Compose([
              # 縮放到固定尺寸
              transforms.Resize((64, 128)),

              # 隨機顏色變換
              transforms.ColorJitter(0.2, 0.2, 0.2),

              # 加入隨機旋轉
              transforms.RandomRotation(5),

              # 將圖片轉換爲pytorch 的tesntor
              # transforms.ToTensor(),

              # 對圖像像素進行歸一化
              # transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
            ]))

通過上述代碼,可以將賽題的圖像數據和對應標籤進行讀取,在讀取過程中的進行數據擴增,效果如下所示:
data augmentation

接下來我們將在定義好的Dataset基礎上構建DataLoder,你可以會問有了Dataset爲什麼還要有DataLoder?其實這兩個是兩個不同的概念,是爲了實現不同的功能。

  • Dataset:對數據集的封裝,提供索引方式的對數據樣本進行讀取
  • DataLoder:對Dataset進行封裝,提供批量讀取的迭代讀取

加入DataLoder後,數據讀取代碼改爲如下:

import os, sys, glob, shutil, json
import cv2

from PIL import Image
import numpy as np

import torch
from torch.utils.data.dataset import Dataset
import torchvision.transforms as transforms

class SVHNDataset(Dataset):
    def __init__(self, img_path, img_label, transform=None):
        self.img_path = img_path
        self.img_label = img_label 
        if transform is not None:
            self.transform = transform
        else:
            self.transform = None

    def __getitem__(self, index):
        img = Image.open(self.img_path[index]).convert('RGB')

        if self.transform is not None:
            img = self.transform(img)
        
        # 原始SVHN中類別10爲數字0
        lbl = np.array(self.img_label[index], dtype=np.int)
        lbl = list(lbl)  + (5 - len(lbl)) * [10]
        
        return img, torch.from_numpy(np.array(lbl[:5]))

    def __len__(self):
        return len(self.img_path)

train_path = glob.glob('../input/train/*.png')
train_path.sort()
train_json = json.load(open('../input/train.json'))
train_label = [train_json[x]['label'] for x in train_json]

train_loader = torch.utils.data.DataLoader(
        SVHNDataset(train_path, train_label,
                   transforms.Compose([
                       transforms.Resize((64, 128)),
                       transforms.ColorJitter(0.3, 0.3, 0.2),
                       transforms.RandomRotation(5),
                       transforms.ToTensor(),
                       transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ])), 
    batch_size=10, # 每批樣本個數
    shuffle=False, # 是否打亂順序
    num_workers=10, # 讀取的線程個數
)

for data in train_loader:
    break

在加入DataLoder後,數據按照批次獲取,每批次調用Dataset讀取單個樣本進行拼接。此時data的格式爲:
torch.Size([10, 3, 64, 128]), torch.Size([10, 6])
前者爲圖像文件,爲batchsize * chanel * height * width次序;後者爲字符標籤。

2.5 本章小節

本章對數據讀取進行了詳細的講解,並介紹了常見的數據擴增方法和使用,最後使用Pytorch框架對本次賽題的數據進行讀取。

對於baseline的理解:

代碼分析:

  • 加載訓練集(*.png)及其label
  • 用dataloader封裝數據集
    baseline的結果(CPU):
    在這裏插入圖片描述

數據擴增:

  • 爲什麼要進行數據擴增?

在深度學習中,一般要求樣本的數量要充足,樣本數量越多,這樣訓練出來的模型效果越好,模型的泛化能力越強。但是實際中,樣本數量不足或者樣本質量不夠好,這就要對樣本做數據增強,來提高樣本質量。我們可以將原來的所有圖片進行隨機改變形式,指定尺寸,隨機裁剪,調整亮度,隨機旋轉角度等自定義操作,得到不同的數據集。

  • 數據增強的作用:

1,增加訓練的數據量,提高模型的泛化能力
2,增加噪聲數據,提升模型的魯棒性

在baseline中使用pytorch進行數據增強:

 train_loader = torch.utils.data.DataLoader(
        SVHNDataset(train_path, train_label,
                    transforms.Compose([
                        transforms.Resize((64, 128)),
                        transforms.RandomCrop((60, 120)),
                        transforms.ColorJitter(0.3, 0.3, 0.2),
                        transforms.RandomRotation(10),
                        transforms.ToTensor(),
                        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                    ])),
        batch_size=256,
        shuffle=True,
        num_workers=0,
    )

pytorch讀取數據

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章