點贊再看,養成習慣!
前言
繼上一節對數據進行極其簡單的數據分析後,這一節開始做數據加載,目標就是組織好數據,可以以一種正確的姿勢餵給後續的模型。不同的深度學習框架,數據加載這一塊是有所不同的,這裏講解的是PyTorch的數據處理工具。
正文
圖像讀取
這裏主要介紹兩個常用的庫:
Pillow【輕量級】
Pillow是Python圖像處理函式庫(PIL)的一個分支。Pillow提供了常見的圖像讀取和處理的操作,而且可以與ipython notebook無縫集成,是應用比較廣泛的庫。
from PIL import Image
# 圖像讀取
im =Image.open(path)
OpenCV【重量級】
OpenCV是一個跨平臺的計算機視覺庫,最早由Intel開源得來。OpenCV發展的非常早,擁有衆多的計算機視覺、數字圖像處理和機器視覺等功能。OpenCV在功能上比Pillow更加強大很多,學習成本也高很多。
import cv2
# 圖像讀取
img = cv2.imread('cat.jpg')
# Opencv默認顏色通道順序是BGR,轉換一下
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
【小編友情提醒】
雖然python程序在使用opencv是導入cv2,但是真正用conda或者pip下載的庫的名字叫opencv-python,這點要格外注意!
數據擴增
在深度學習中數據擴增方法非常重要,數據擴增可以增加訓練集的樣本,同時也可以有效緩解模型過擬合的情況,也可以給模型帶來的更強的泛化能力。這裏是針對圖像數據進行擴增,所以常見的角度有圖像顏色、尺寸、形態、空間和像素等。其實小編以前常見常用的也只有圖像顏色變化、翻轉、裁剪這三種操作。不過這裏字符不可以進行翻轉,例如6倒過來會變成9,改變字符原先的含義。
常見的庫
- torchvision
pytorch官方提供的數據擴增庫,提供了基本的數據數據擴增方法,可以無縫與torch進行集成;但數據擴增方法種類較少,且速度中等;
常用方法:
transforms.RandomCrop 隨機區域裁剪
transforms.ColorJitter 對圖像顏色的對比度、飽和度和零度進行變換
transforms.Grayscale 對圖像進行灰度變換
transforms.Pad 使用固定值進行像素填充
transforms.RandomRotation 隨機旋轉
SVHNDataset(train_path, train_label,
transforms.Compose([
transforms.Resize((64, 128)),
transforms.ColorJitter(0.3, 0.3, 0.2), #顏色變化
transforms.RandomRotation(5), #隨機旋轉,不能旋轉太多
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]))
- imgaug
imgaug是常用的第三方數據擴增庫,提供了多樣的數據擴增方法,且組合起來非常方便,速度較快; - albumentations
是常用的第三方數據擴增庫,提供了多樣的數據擴增方法,對圖像分類、語義分割、物體檢測和關鍵點檢測都支持,速度較快。
圖像擴增示例效果圖:
PyTorch數據加載
PyTorch數據加載的過程是:數據集本身要轉化成Dataset實例,而提供給模型訓練、驗證或測試時的讀取要用DataLoader實例。
- Dataset:對數據集的封裝,提供索引方式的對數據樣本進行讀取
- DataLoader:對Dataset進行封裝,提供批量讀取的迭代讀取,可以用多進程加速
實施流程:
- 繼承Dataset類,並實現__init__、getitem、__len__等函數成員,這裏類名爲SVHNDataset。
class SVHNDataset(Dataset):
def __init__(self, img_path, img_label, transform=None):
self.img_path = img_path #所有圖像數據路徑
self.img_label = img_label #所有圖像標籤數據
if transform is not None:
self.transform = transform #預處理流
else:
self.transform = None
def __getitem__(self, index):
# just handle one data
img = Image.open(self.img_path[index]).convert('RGB') #讀取圖像
if self.transform is not None:
img = self.transform(img) #預處理
# 定長字符識別策略,填充的字符爲10,這樣不會與有效字符0-9發生碰撞
lbl = np.array(self.img_label[index], dtype=np.int)
lbl = list(lbl) + (5 - len(lbl)) * [10]
return img, torch.from_numpy(np.array(lbl[:5]))
def __len__(self):
return len(self.img_path) #數據集大小
- DataLoader加載SVHNDataset
train_loader = torch.utils.data.DataLoader(
SVHNDataset(train_path, train_label,
transforms.Compose([
transforms.Resize((64, 128)),
transforms.ColorJitter(0.3, 0.3, 0.2),
transforms.RandomRotation(5),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])),
batch_size=10, # 每批樣本個數
shuffle=False, # 是否打亂順序
num_workers=5, #進程個數
)
結語
PyTorch數據加載的流程較爲固定,但因爲Dataset能夠自定義,所以數據讀取就比較靈活。值得說一句的是,數據預處理的數據擴增並不是說直接擴增數據,比如把3W的訓練集擴增到更多,而是在深度學習的訓練過程中把每張圖片都通過transform處理流進行變化,這樣不同的迭代中同一索引的圖像都不一定相同,從而達到了數據擴增的目標。
參考文獻
- Pillow的官方文檔:https://pillow.readthedocs.io/en/stable/
- OpenCV官網:https://opencv.org/
OpenCV Github:https://github.com/opencv/opencv
OpenCV 擴展算法庫:https://github.com/opencv/opencv_contrib - torchvision: https://github.com/pytorch/vision
- imgaug: https://github.com/aleju/imgaug
- albumentations: https://albumentations.readthedocs.io
童鞋們,讓小編聽見你們的聲音,點贊評論,一起加油。