[pytorch] Test time augmentation

[pytorch] Test time augmentation

1.什麼是Test time augmentation

train的時候我們經常加入data augmentation, 比如旋轉,對比度調整,gamma變換等等,其實本質上是爲了增加泛化性。在test的時候,同樣可以加入augmented images,相當於一個ensemble,模型分數也會有所提高。本文我寫了翻轉,旋轉90°倍數的TTA方法。

2.pytorch

# -*- coding: utf-8 -*-
# @Time    : 2020/1/10 12:20
# @Author  : Mingxing Li
# @FileName: fusion.py
# @Software: PyCharm

# import network1.twonet
# import network2.twonet
import torch

class Test_time_agumentation(object):

    def __init__(self, is_rotation=True):
        self.is_rotation = is_rotation

    def __rotation(self, img):
        """
        clockwise rotation 90 180 270
        """
        img90 = img.rot90(-1, [2, 3]) # 1 逆時針; -1 順時針
        img180 = img.rot90(-1, [2, 3]).rot90(-1, [2, 3])
        img270 = img.rot90(1, [2, 3])
        return [img90, img180, img270]

    def __inverse_rotation(self, img90, img180, img270):
        """
        anticlockwise rotation 90 180 270
        """
        img90 = img90.rot90(1, [2, 3]) # 1 逆時針; -1 順時針
        img180 = img180.rot90(1, [2, 3]).rot90(1, [2, 3])
        img270 = img270.rot90(-1, [2, 3])
        return img90, img180, img270

    def __flip(self, img):
        """
        Flip vertically and horizontally
        """
        return [img.flip(2), img.flip(3)]

    def __inverse_flip(self, img_v, img_h):
        """
        Flip vertically and horizontally
        """
        return img_v.flip(2), img_h.flip(3)

    def tensor_rotation(self, img):
        """
        img size: [H, W]
        rotation degree: [90 180 270]
        :return a rotated list
        """
        # assert img.shape == (1024, 1024)
        return self.__rotation(img)

    def tensor_inverse_rotation(self, img_list):
        """
        img size: [H, W]
        rotation degree: [90 180 270]
        :return a rotated list
        """
        # assert img.shape == (1024, 1024)
        return self.__inverse_rotation(img_list[0], img_list[1], img_list[2])

    def tensor_flip(self, img):
        """
        img size: [H, W]
        :return a flipped list
        """
        # assert img.shape == (1024, 1024)
        return self.__flip(img)

    def tensor_inverse_flip(self, img_list):
        """
        img size: [H, W]
        :return a flipped list
        """
        # assert img.shape == (1024, 1024)
        return self.__inverse_flip(img_list[0], img_list[1])


if __name__ == "__main__":
    a = torch.tensor([[0, 1],[2, 3]]).unsqueeze(0).unsqueeze(0)
    print(a)
    tta = Test_time_agumentation()
    # a = tta.tensor_rotation(a)
    a = tta.tensor_flip(a)
    print(a)
    a = tta.tensor_inverse_flip(a)
    print(a)

同時我將代碼release到了https://github.com/Limingxing00/Test-time-augmentation。
代碼後續更新將在GitHub上。歡迎大家討論交流!

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章