PyTorch 源碼解讀:torchvision 的 folder.py

在使用 PyTorch 構建數據集時,常會用到 torchvision 中的 ImageFolder 類來方便地加載數據;瞭解其源碼有助於快速開發。

import torch.utils.data as data
#PIL: Python Image Library縮寫,圖像處理模塊
#     Image,ImageFont,ImageDraw,ImageFilter
from PIL import Image    
import os
import os.path

# Supported image file extensions. is_image_file matches them
# case-insensitively, so the upper-case spellings are kept for display only.
IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]


def is_image_file(filename):
    """Return True if *filename* ends with one of ``IMG_EXTENSIONS``.

    The comparison is case-insensitive, so mixed-case names such as
    ``photo.Jpg`` are accepted in addition to the spellings listed explicitly.

    Args:
        filename (string): File name (or path) to check.

    Returns:
        bool: Whether the name carries a supported image extension.
    """
    # Lower-case both sides so any capitalisation matches; str.endswith
    # accepts a tuple of suffixes and checks them all in one call.
    suffixes = tuple(ext.lower() for ext in IMG_EXTENSIONS)
    return filename.lower().endswith(suffixes)

def find_classes(dir):
    """Discover the class folders directly under *dir*.

    Every immediate sub-directory of ``dir`` is one class; plain files are
    ignored. Names are sorted so the name -> index mapping is deterministic.
    Note that the index is the position in that sorted list, not ``int(name)``:
    for folders ``0`` .. ``9`` the two coincide, e.g.
    ``(['0', '1', ..., '9'], {'0': 0, '1': 1, ..., '9': 9})``, but with a
    folder ``10`` the lexicographic sort places it before ``2``.

    Args:
        dir (string): Root directory. A leading ``~`` / ``~user`` is expanded,
            the same way ``make_dataset`` expands its root.

    Returns:
        tuple: ``(classes, class_to_idx)`` where ``classes`` is the sorted list
        of sub-directory names and ``class_to_idx`` maps each name to its
        position in that list.
    """
    dir = os.path.expanduser(dir)
    # os.listdir mixes files and directories; keep only the directories.
    classes = sorted(
        entry for entry in os.listdir(dir)
        if os.path.isdir(os.path.join(dir, entry))
    )
    class_to_idx = {name: idx for idx, name in enumerate(classes)}
    return classes, class_to_idx


def make_dataset(dir, class_to_idx):
    """Collect ``(image_path, class_index)`` pairs for every image under *dir*.

    For each class name in ``class_to_idx`` (in sorted order), the matching
    sub-directory ``dir/<class>`` is walked recursively, and every file that
    ``is_image_file`` accepts is recorded with that class's index.

    Names in ``class_to_idx`` with no matching sub-directory are skipped.
    Sub-directories that are not in ``class_to_idx`` are ignored rather than
    raising ``KeyError``. A caller may therefore pass a filtered mapping to
    load only some classes.

    Args:
        dir (string): Root directory; ``~`` / ``~user`` is expanded.
        class_to_idx (dict): Maps class (sub-directory) name to integer label.

    Returns:
        list: ``(path, class_index)`` tuples, ordered by class name, then by
        sorted ``os.walk`` entry, then by file name.
    """
    images = []
    dir = os.path.expanduser(dir)
    # Drive the loop from the mapping, not from the directory listing, so an
    # unexpected folder on disk cannot trigger a KeyError below.
    for target in sorted(class_to_idx):
        class_dir = os.path.join(dir, target)
        if not os.path.isdir(class_dir):
            continue
        # os.walk yields (dirpath, dirnames, filenames) triples; sorting them
        # (and the file names) makes the sample order deterministic.
        for root, _, fnames in sorted(os.walk(class_dir)):
            for fname in sorted(fnames):
                if is_image_file(fname):
                    path = os.path.join(root, fname)
                    images.append((path, class_to_idx[target]))
    return images

def pil_loader(path):
    """Read the image stored at *path* with PIL and return it in RGB mode.

    The file is opened explicitly, rather than passing the path straight to
    ``Image.open``. This closes the handle deterministically and avoids a
    ResourceWarning (https://github.com/python-pillow/Pillow/issues/835).
    ``convert`` returns a new image object, so the result remains usable
    after both context managers have exited.
    """
    with open(path, 'rb') as handle, Image.open(handle) as picture:
        return picture.convert('RGB')


def accimage_loader(path):
    """Load *path* with the optional ``accimage`` backend.

    ``accimage`` is imported here, not at module level, so it is only
    required when this loader is actually used. If accimage raises
    ``IOError`` (for example, a file it cannot decode), loading is retried
    with ``pil_loader``.
    """
    import accimage
    try:
        image = accimage.Image(path)
    except IOError:
        # Possibly a decoding problem; fall back to PIL.
        return pil_loader(path)
    return image


def default_loader(path):
    """Load *path* with the loader matching torchvision's image backend.

    Uses ``accimage_loader`` when ``torchvision.get_image_backend()`` reports
    ``'accimage'``; otherwise uses ``pil_loader``.
    """
    from torchvision import get_image_backend
    use_accimage = get_image_backend() == 'accimage'
    return accimage_loader(path) if use_accimage else pil_loader(path)


class ImageFolder(data.Dataset):
    """A generic data loader where the images are arranged in this way: ::

        root/dog/xxx.png
        root/dog/xxy.png
        root/dog/xxz.png

        root/cat/123.png
        root/cat/nsdf3.png
        root/cat/asd932_.png

    Each immediate sub-directory of ``root`` is one class. Images are indexed
    once at construction time and loaded lazily, one per ``__getitem__`` call.

    Args:
        root (string): Root directory path.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        loader (callable, optional): A function to load an image given its path.

    Attributes:
        classes (list): List of the class names.
        class_to_idx (dict): Dict with items (class_name, class_index).
        imgs (list): List of (image path, class_index) tuples

    Raises:
        RuntimeError: If no supported image file is found under ``root``.
    """

    def __init__(self, root, transform=None, target_transform=None,
                 loader=default_loader):
        # Discover the class sub-folders and their integer labels.
        classes, class_to_idx = find_classes(root)
        # Index every image file under root as (path, class_index).
        imgs = make_dataset(root, class_to_idx)
        if not imgs:
            # str.format (not "+") so a non-str root such as pathlib.Path
            # still produces this RuntimeError instead of a TypeError.
            raise RuntimeError(
                "Found 0 images in subfolders of: {}\n"
                "Supported image extensions are: {}".format(
                    root, ",".join(IMG_EXTENSIONS)))

        self.root = root
        self.imgs = imgs
        self.classes = classes
        self.class_to_idx = class_to_idx
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is class_index of the target class.
        """
        # Only the path is stored; the loader turns it into an image object
        # for this one sample.
        path, target = self.imgs[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.imgs)

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章