Learning MXNet (7): Data Loading Methods

1. Reading .rec files with ImageRecordIter

mxnet.io.ImageRecordIter(*args, **kwargs)

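This method can only read batches from .rec files. Compared with a customized input pipeline it is less flexible, but it is very fast. To read raw image files, use ImageIter instead.

Example: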

import mxnet as mx

data_iter = mx.io.ImageRecordIter(
  path_imgrec="./sample.rec", # The target record file.
  data_shape=(3, 227, 227), # Output data shape; 227x227 region will be cropped from the original image.
  batch_size=4, # Number of items per batch.
  resize=256 # Resize the shorter edge to 256 before cropping.
  # You can specify more augmentation options. Use help(mx.io.ImageRecordIter) to see all the options.
  )
# You can now use the data_iter to access batches of images.
batch = data_iter.next() # first batch.
images = batch.data[0] # This will contain 4 (=batch_size) images each of 3x227x227.
# process the images
...
data_iter.reset() # To restart the iterator from the beginning.

Various augmentation operations can be specified through the parameters. For the full parameter list, see:

http://mxnet.incubator.apache.org/versions/master/api/python/io/io.html?highlight=record

2. Reading .rec files or raw images with mxnet.image.ImageIter

class mxnet.image.ImageIter(
                            batch_size,
                            data_shape, # only 3-channel RGB is supported
                            label_width=1, 
                            path_imgrec=None,
                            path_imglist=None, 
                            path_root=None, 
                            path_imgidx=None, 
                            shuffle=False, 
                            part_index=0, 
                            num_parts=1, 
                            aug_list=None, 
                            imglist=None, 
                            data_name ='data', 
                            label_name ='softmax_label', 
                            dtype='float32', 
                            last_batch_handle='pad', 
                            **kwargs
                            )

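This is a data iterator with a large number of built-in augmentation operations. It can read data from .rec files or from raw image files.

Use the path_imgrec parameter to load a .rec file; use path_imglist (together with path_root) to load raw images.

Specify the path_imgidx parameter to enable data partitioning (for distributed training) or shuffling.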
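Data partitioning is controlled by the part_index and num_parts arguments in the signature above: each worker reads only its own slice of the samples. As a minimal sketch of the idea, assuming an even split in which any remainder is simply left out (the function name partition_indices is hypothetical, and this is an illustration rather than MXNet's implementation):

```python
def partition_indices(num_samples, part_index, num_parts):
    """Return the sample indices that one worker would read.

    Assumption for this sketch: every part gets num_samples // num_parts
    samples and the trailing remainder is not assigned to any part.
    """
    if not 0 <= part_index < num_parts:
        raise ValueError("part_index must be in [0, num_parts)")
    per_part = num_samples // num_parts
    start = part_index * per_part
    return list(range(start, start + per_part))


# Three workers sharing a 10-sample dataset.
for rank in range(3):
    print(rank, partition_indices(10, rank, 3))
```

In a distributed job, each worker would pass its own rank as part_index and the total number of workers as num_parts.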

References:

http://mxnet.incubator.apache.org/versions/master/api/python/image/image.html#mxnet.image.ImageIter
https://blog.csdn.net/u014380165/article/details/74906061

A usage example:

import mxnet as mx
import numpy as np
import matplotlib.pyplot as plt

data_iter = mx.image.ImageIter(batch_size=4, data_shape=(3, 227, 227),
                              path_imgrec="./data/caltech.rec",
                              path_imgidx="./data/caltech.idx")

# data_iter is an mxnet.image.ImageIter.
# reset() resets the iterator to the beginning of the data.
data_iter.reset()

# batch is an mxnet.io.DataBatch, which is what next() returns.
batch = data_iter.next()

# data is an NDArray holding the images of the first batch. Since batch_size is 4,
# its shape is 4 x 3 x 227 x 227.
data = batch.data[0]

# Display every image in the batch (CHW -> HWC for matplotlib).
for i in range(4):
    plt.subplot(1, 4, i + 1)
    plt.imshow(data[i].asnumpy().astype(np.uint8).transpose((1, 2, 0)))
plt.show()

Image augmentation with mx.image.CreateAugmenter():

train = mx.image.ImageIter(
        batch_size   = args.batch_size,
        data_shape   = (3, 224, 224),
        label_width  = 1,
        path_imglist = args.data_train,
        path_root    = args.image_train,
        part_index   = rank,
        shuffle      = True,
        data_name    = 'data',
        label_name   = 'softmax_label',
        aug_list     = mx.image.CreateAugmenter((3, 224, 224), resize=224,
                                                rand_crop=True, rand_mirror=True, mean=True))

Settings and parameters of image.CreateAugmenter:

image.CreateAugmenter(
                data_shape,
                resize=0,
                rand_crop=False,
                rand_resize=False,
                rand_mirror=False,
                mean=None, # if True, the default ImageNet mean is used
                std=None, # likewise: if True, the default ImageNet std is used
                brightness=0,
                contrast=0,
                saturation=0,
                hue=0,
                pca_noise=0,
                rand_gray=0,
                inter_method=2
                )
#Creates an augmenter list.

Parameters:

  • data_shape (tuple of int) – Shape for output data
  • resize (int) – Resize shorter edge if larger than 0 at the beginning
  • rand_crop (bool) – Whether to enable random cropping other than center crop
  • rand_resize (bool) – Whether to enable random sized cropping, require rand_crop to be enabled
  • rand_gray (float) – [0, 1], probability to convert to grayscale for all channels, the number of channels will not be reduced to 1
  • rand_mirror (bool) – Whether to apply horizontal flip to image with probability 0.5
  • mean (np.ndarray or None) – Mean pixel values for [r, g, b]
  • std (np.ndarray or None) – Standard deviations for [r, g, b]
  • brightness (float) – Brightness jittering range (percent)
  • contrast (float) – Contrast jittering range (percent)
  • saturation (float) – Saturation jittering range (percent)
  • hue (float) – Hue jittering range (percent)
  • pca_noise (float) – Pca noise level (percent)
  • inter_method (int, default=2 (Area-based)) – Interpolation method for all resizing operations. Possible values:
      0: Nearest Neighbors interpolation.
      1: Bilinear interpolation.
      2: Area-based (resampling using pixel area relation); used by default. It may be a preferred method for image decimation, as it gives moire-free results, but when the image is zoomed it is similar to the Nearest Neighbors method.
      3: Bicubic interpolation over a 4x4 pixel neighborhood.
      4: Lanczos interpolation over an 8x8 pixel neighborhood.
      9: Cubic for enlarging, area for shrinking, bilinear for others.
      10: Random selection from the interpolation methods mentioned above.
    Note: When shrinking an image, it will generally look best with Area-based interpolation, whereas when enlarging an image, it will generally look best with Bicubic (slow) or Bilinear (faster but still looks OK).
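The mean and std parameters describe per-channel normalization. As a rough sketch of the arithmetic only (plain Python, not MXNet code): assuming the commonly used ImageNet statistics, normalizing raw 0-255 pixels with 0-255-scale statistics gives the same result as first scaling pixels to 0-1 and then using the 0-1-scale values that appear as rgb_mean and rgb_std in the custom transform of section 4. The helper names below are hypothetical.

```python
import math

# Commonly used ImageNet channel statistics on the 0-1 scale (assumption:
# the same values as rgb_mean / rgb_std in the custom transform of section 4).
MEAN_01 = (0.485, 0.456, 0.406)
STD_01 = (0.229, 0.224, 0.225)


def normalize_01(pixel):
    """Scale an (r, g, b) pixel to 0-1, then subtract mean and divide by std."""
    return tuple((p / 255.0 - m) / s for p, m, s in zip(pixel, MEAN_01, STD_01))


def normalize_255(pixel):
    """Normalize the raw 0-255 pixel with statistics rescaled to 0-255."""
    mean_255 = [255.0 * m for m in MEAN_01]
    std_255 = [255.0 * s for s in STD_01]
    return tuple((p - m) / s for p, m, s in zip(pixel, mean_255, std_255))


pixel = (128, 64, 255)
a, b = normalize_01(pixel), normalize_255(pixel)
# Both conventions agree up to floating-point rounding.
print(all(math.isclose(x, y, rel_tol=1e-9, abs_tol=1e-9) for x, y in zip(a, b)))
```

Whichever convention is used, the statistics must be on the same scale as the pixel values they are applied to.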

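3. Using Dataset and DataLoader

Gluon provides a way to load data with Dataset and DataLoader that is very similar to PyTorch.

Reference: https://mxnet.incubator.apache.org/versions/master/tutorials/gluon/datasets.html

A Dataset object represents a collection of data together with the methods for loading and parsing it. Gluon ships many Dataset classes; ArrayDataset is used below for illustration.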

import mxnet as mx
import os
import tarfile

mx.random.seed(42) # Fix the seed for reproducibility
X = mx.random.uniform(shape=(10, 3))
y = mx.random.uniform(shape=(10, 1))
dataset = mx.gluon.data.dataset.ArrayDataset(X, y)

One important feature of a Dataset is that a sample can be retrieved by its index.

sample_idx = 4
sample = dataset[sample_idx]

assert len(sample) == 2
assert sample[0].shape == (3, )#data
assert sample[1].shape == (1, )#label
print(sample)
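Conceptually, index-based retrieval only requires an object that implements __getitem__ and __len__. The following plain-Python sketch mirrors the example above with lists instead of NDArrays; the class name PairDataset is hypothetical and is not part of Gluon.

```python
class PairDataset:
    """Minimal dataset sketch: pairs each data row with its label by position."""

    def __init__(self, data, labels):
        if len(data) != len(labels):
            raise ValueError("data and labels must have the same length")
        self._data = data
        self._labels = labels

    def __getitem__(self, idx):
        # One sample is a (data, label) tuple, as in the ArrayDataset example.
        return self._data[idx], self._labels[idx]

    def __len__(self):
        return len(self._data)


# Ten samples with three features each, and one label per sample.
data = [[float(i), i + 0.5, i + 1.0] for i in range(10)]
labels = [[i % 2] for i in range(10)]
dataset = PairDataset(data, labels)

sample = dataset[4]
print(len(dataset), sample)
```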

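However, we usually do not index a Dataset directly; we use a DataLoader instead.

A DataLoader builds mini-batches from a Dataset and provides a convenient iterator interface for looping over the batches. Its most important parameter is batch_size.

Another advantage of DataLoader is that it can load data in parallel with several worker processes, set through the num_workers parameter.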

from multiprocessing import cpu_count
CPU_COUNT = cpu_count()

data_loader = mx.gluon.data.DataLoader(dataset, batch_size=5, num_workers=CPU_COUNT)

for X_batch, y_batch in data_loader:
    print("X_batch has shape {}, and y_batch has shape {}".format(X_batch.shape, y_batch.shape))

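The loader loop stops once every sample in the dataset has been returned as part of a batch. Sometimes the number of samples is not divisible by batch_size. By default, the last iteration then returns a batch smaller than batch_size. Alternatively, set the last_batch parameter to 'discard' (drop the incomplete last batch) or 'rollover' (carry the remaining samples over to the start of the next epoch).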
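The three last-batch policies can be made concrete with a small pure-Python sketch. It only mimics the behavior described above on lists of indices; the names BatchIndexSampler and epoch_batch_sizes are hypothetical, and this is not MXNet's implementation.

```python
class BatchIndexSampler:
    """Groups indices 0..num_samples-1 into batches under a last-batch policy."""

    def __init__(self, num_samples, batch_size, last_batch="keep"):
        if last_batch not in ("keep", "discard", "rollover"):
            raise ValueError("unknown last_batch: " + last_batch)
        self.num_samples = num_samples
        self.batch_size = batch_size
        self.last_batch = last_batch
        self._carry = []  # leftover indices, only filled by 'rollover'

    def __iter__(self):
        # Start the epoch with whatever was carried over from the previous one.
        batch, self._carry = self._carry, []
        for idx in range(self.num_samples):
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        if batch:
            if self.last_batch == "keep":
                yield batch          # smaller final batch
            elif self.last_batch == "rollover":
                self._carry = batch  # saved for the next epoch
            # 'discard': the incomplete batch is dropped


def epoch_batch_sizes(sampler):
    """Run one epoch and report the size of every batch."""
    return [len(batch) for batch in sampler]


print(epoch_batch_sizes(BatchIndexSampler(10, 4, "keep")))     # [4, 4, 2]
print(epoch_batch_sizes(BatchIndexSampler(10, 4, "discard")))  # [4, 4]
rollover = BatchIndexSampler(10, 4, "rollover")
print(epoch_batch_sizes(rollover))  # [4, 4]; two samples are carried over
print(epoch_batch_sizes(rollover))  # [4, 4, 4]; the carried samples come first
```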

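Loading custom data with a Dataset

Gluon has many Dataset classes. Among them, mxnet.gluon.data.vision.datasets.ImageFolderDataset loads data directly from a user-defined folder and infers the label (class) of each image from the folder structure.

To use this class, images with different labels must be placed in separate subfolders.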

train_dataset = mx.gluon.data.vision.datasets.ImageFolderDataset(training_path)
test_dataset = mx.gluon.data.vision.datasets.ImageFolderDataset(testing_path)
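To illustrate the one-subfolder-per-class convention, the sketch below scans such a layout with the standard library only and assigns each subfolder a class index. It assumes class indices follow the sorted subfolder names; the function scan_image_folder, the extension list, and the file names are hypothetical, and the code is not the library's implementation.

```python
import os
import tempfile

IMAGE_EXTS = (".jpg", ".jpeg", ".png")  # assumption: extensions treated as images


def scan_image_folder(root):
    """Return (class_names, items); items are (file_path, class_index) pairs."""
    class_names = sorted(
        d for d in os.listdir(root) if os.path.isdir(os.path.join(root, d))
    )
    items = []
    for index, name in enumerate(class_names):
        folder = os.path.join(root, name)
        for filename in sorted(os.listdir(folder)):
            if filename.lower().endswith(IMAGE_EXTS):
                items.append((os.path.join(folder, filename), index))
    return class_names, items


# Build a tiny layout for the demo: root/cat/*.jpg and root/dog/*.jpg.
root = tempfile.mkdtemp()
layout = {"dog": ["d1.jpg"], "cat": ["c1.jpg", "c2.jpg", "notes.txt"]}
for name, files in layout.items():
    os.makedirs(os.path.join(root, name))
    for filename in files:
        open(os.path.join(root, name, filename), "w").close()

class_names, items = scan_image_folder(root)
print(class_names)
print([(os.path.basename(path), label) for path, label in items])
```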

There is also a Dataset class that reads .rec files directly:

class mxnet.gluon.data.vision.datasets.ImageRecordDataset(filename, flag=1, transform=None)

A dataset wrapping over a RecordIO file containing images.
Each sample is an image and its corresponding label.

Parameters:

  • filename (str) – Path to rec file.
  • flag ({0, 1}, default 1) – If 0, always convert images to greyscale. If 1, always convert images to colored (RGB).
  • transform (function, default None) –
    A user defined callback that transforms each sample. For example:
    transform=lambda data, label: (data.astype(np.float32)/255, label)

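You can also load data by defining your own Dataset class, as described in the next section.

4. Loading data with a custom Dataset

The official documentation describes a way to load data by defining a custom dataset. This approach is convenient and flexible, and can be adapted to your own needs.

Reference: https://mxnet.incubator.apache.org/versions/master/tutorials/python/data_augmentation_with_masks.html

Following the referenced document, suppose you need to read raw images from a list file in which the first column of each line is the image path and the second column is the image label. The code below can serve as a starting point.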

import mxnet as mx
from mxnet.gluon.data import dataset
from mxnet.gluon.data.vision import datasets, transforms
from mxnet import gluon, nd
import os
import cv2
import time
class readImageFromList(dataset.Dataset):
    def __init__(self, image_path, text_file, transform = None):
        self._transform = transform
        self._image_path = image_path
        self._text_file = text_file
        # Each line is "<image path>\t<label>"; read the file once and close it.
        with open(self._text_file, "r") as f:
            rows = [line.strip("\n").split("\t") for line in f if line.strip()]
        self._images = [row[0] for row in rows]
        self._labels = [row[1] for row in rows]
    def __getitem__(self, idx):
        file_name = os.path.join(self._image_path, self._images[idx])
        if not os.path.isfile(file_name):
            # Fail loudly: falling through would leave `image` undefined below.
            raise IOError(file_name + " cannot be found.")
        image = mx.image.imread(file_name)
        # Wrap the label in an NDArray (it is unclear whether this conversion is strictly required).
        label = nd.array([int(self._labels[idx])])
        if self._transform is not None:
            return self._transform(image), label
        else:
            return image, label
    def __len__(self):
        return len(self._images)

class imageTransform():
    def __init__(self):
        self.resize = mx.image.ResizeAug(256)
        self.crop = mx.image.RandomCropAug((224, 224))
        self.flip = mx.image.HorizontalFlipAug(p = 0.5)
        self.cast = mx.image.CastAug(typ = 'float32')
        self.bright = mx.image.BrightnessJitterAug(0.1)
        self.contrast = mx.image.ContrastJitterAug(0.1)
        self.color = mx.image.ColorJitterAug(0.1, 0.1, 0.1)
        self.rgb_mean = nd.array([0.485, 0.456, 0.406]).reshape(3, 1, 1)
        self.rgb_std = nd.array([0.229, 0.224, 0.225]).reshape(3, 1, 1)
    def __call__(self, image):
        img = self.resize(image)
        img = self.crop(img)
        img = self.flip(img)
        img = self.cast(img)
        img = self.color(img)
        img = img.transpose((2, 0, 1))
        img = (img.astype('float32') / 255 - self.rgb_mean) / self.rgb_std
        return img
if __name__ == "__main__":
    #transformer = transforms.Compose([transforms.ToTensor(), transforms.Normalize(0.13, 0.31)])
    transformer = imageTransform()
    image_train = readImageFromList(image_path = "dogvscat/train/", text_file = "train_list.txt", transform = transformer)
    batch_size = 128
    train_data = gluon.data.DataLoader(image_train, batch_size = batch_size, shuffle = True, num_workers = 1)
    for data, label in train_data:
        print(data.shape)
        print(label.shape)
        break

Note that all images must have the same size so that they can be stacked into a batch.
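The list file read by readImageFromList in this section is plain text with one tab-separated "path, label" pair per line. A minimal sketch of writing and re-reading such a file with the standard library follows; the helper names, file names, and labels are hypothetical and serve only as illustration.

```python
import os
import tempfile


def write_list_file(list_path, entries):
    """Write one 'relative_path<TAB>label' line per (path, label) pair."""
    with open(list_path, "w") as f:
        for rel_path, label in entries:
            f.write("{}\t{}\n".format(rel_path, int(label)))


def read_list_file(list_path):
    """Parse the file back into (paths, labels), skipping blank lines."""
    paths, labels = [], []
    with open(list_path, "r") as f:
        for line in f:
            line = line.rstrip("\n")
            if not line:
                continue
            rel_path, label = line.split("\t")
            paths.append(rel_path)
            labels.append(int(label))
    return paths, labels


# Hypothetical file names and labels (0 = cat, 1 = dog).
entries = [("cat.0.jpg", 0), ("dog.0.jpg", 1), ("cat.1.jpg", 0)]
list_path = os.path.join(tempfile.mkdtemp(), "train_list.txt")
write_list_file(list_path, entries)

paths, labels = read_list_file(list_path)
print(paths, labels)
```

The paths are relative to the image root, which matches how the dataset class joins image_path with each entry.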

5. Using NDArrayIter

I have not tried this one yet. Note that NDArrayIter is single-threaded only, whereas ImageIter can use multiple threads.

import mxnet as mx 
import numpy as np 
import random 

batch_size = 5
dataset_length = 50

# random seeds
random.seed(1)
np.random.seed(1)
mx.random.seed(1)

train_data = np.random.rand(dataset_length, 28,28).astype('float32')
train_label = np.random.randint(0, 10, (dataset_length,)).astype('float32')

data_iter = mx.io.NDArrayIter(data=train_data, label=train_label, batch_size=batch_size, shuffle=False, data_name='data', label_name='softmax_label')
for batch in data_iter:
    print(batch.data[0].shape, batch.label[0])
    break

Appendix

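As the methods above show, data loading falls into two main styles:

  • The traditional DataIter style, which returns a DataBatch with two attributes, data and label, each holding arrays.
  • The Gluon style of Dataset + DataLoader, which returns a (data, label) tuple.

However, the output of a DataIter cannot be used directly where a DataLoader is expected. When using Gluon, it is recommended to wrap the DataIter so that it can be consumed like a DataLoader. Operations such as augmentation need no special attention, because they can still be done inside the DataIter.

A simple class can wrap a DataIter object into something that a typical Gluon training loop can use. The class works with objects such as mxnet.image.ImageIter and mxnet.io.ImageRecordIter.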

class DataIterLoader():
    def __init__(self, data_iter):
        self.data_iter = data_iter

    def __iter__(self):
        self.data_iter.reset()
        return self

    def __next__(self):
        batch = self.data_iter.__next__()
        assert len(batch.data) == len(batch.label) == 1
        data = batch.data[0]
        label = batch.label[0]
        return data, label

    def next(self):
        return self.__next__() # for Python 2

# X and y are the arrays created in the ArrayDataset example of section 3.
data_iter = mx.io.NDArrayIter(data=X, label=y, batch_size=5)
data_iter_loader = DataIterLoader(data_iter)
for X_batch, y_batch in data_iter_loader:
    assert X_batch.shape == (5, 3)
    assert y_batch.shape == (5, 1)
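
Because __iter__ calls reset(), the wrapper can be looped over once per epoch. The sketch below checks that behavior with hypothetical stand-ins (FakeBatch and FakeDataIter) so that it runs without MXNet. They only mimic the reset() / __next__() / DataBatch interface that the wrapper relies on and are not part of any library.

```python
class FakeBatch:
    """Hypothetical stand-in for mxnet.io.DataBatch: data and label are lists."""

    def __init__(self, data, label):
        self.data = [data]
        self.label = [label]


class FakeDataIter:
    """Hypothetical stand-in for a DataIter exposing reset() and __next__()."""

    def __init__(self, samples, labels, batch_size):
        self._samples = samples
        self._labels = labels
        self._batch_size = batch_size
        self._cursor = 0

    def reset(self):
        self._cursor = 0

    def __next__(self):
        if self._cursor >= len(self._samples):
            raise StopIteration
        start, end = self._cursor, self._cursor + self._batch_size
        self._cursor = end
        return FakeBatch(self._samples[start:end], self._labels[start:end])


class DataIterLoader:
    """Same wrapper as above, repeated so that this sketch runs on its own."""

    def __init__(self, data_iter):
        self.data_iter = data_iter

    def __iter__(self):
        self.data_iter.reset()
        return self

    def __next__(self):
        batch = self.data_iter.__next__()
        assert len(batch.data) == len(batch.label) == 1
        return batch.data[0], batch.label[0]


samples = list(range(10))
labels = [i % 2 for i in range(10)]
loader = DataIterLoader(FakeDataIter(samples, labels, batch_size=5))

# Two epochs over the same loader: each epoch restarts from the first batch.
epochs = [[(x, y) for x, y in loader] for _ in range(2)]
print(len(epochs[0]), len(epochs[1]))
```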