[pytorch] Test time augmentation

[pytorch] Test time augmentation

1.什么是Test time augmentation

train的时候我们经常加入data augmentation, 比如旋转,对比度调整,gamma变换等等,其实本质上是为了增加泛化性。在test的时候,同样可以加入augmented images,相当于一个ensemble,模型分数也会有所提高。本文我写了翻转,旋转90°倍数的TTA方法。

2.pytorch

# -*- coding: utf-8 -*-
# @Time    : 2020/1/10 12:20
# @Author  : Mingxing Li
# @FileName: fusion.py
# @Software: PyCharm

# import network1.twonet
# import network2.twonet
import torch

class Test_time_agumentation(object):

    def __init__(self, is_rotation=True):
        self.is_rotation = is_rotation

    def __rotation(self, img):
        """
        clockwise rotation 90 180 270
        """
        img90 = img.rot90(-1, [2, 3]) # 1 逆时针; -1 顺时针
        img180 = img.rot90(-1, [2, 3]).rot90(-1, [2, 3])
        img270 = img.rot90(1, [2, 3])
        return [img90, img180, img270]

    def __inverse_rotation(self, img90, img180, img270):
        """
        anticlockwise rotation 90 180 270
        """
        img90 = img90.rot90(1, [2, 3]) # 1 逆时针; -1 顺时针
        img180 = img180.rot90(1, [2, 3]).rot90(1, [2, 3])
        img270 = img270.rot90(-1, [2, 3])
        return img90, img180, img270

    def __flip(self, img):
        """
        Flip vertically and horizontally
        """
        return [img.flip(2), img.flip(3)]

    def __inverse_flip(self, img_v, img_h):
        """
        Flip vertically and horizontally
        """
        return img_v.flip(2), img_h.flip(3)

    def tensor_rotation(self, img):
        """
        img size: [H, W]
        rotation degree: [90 180 270]
        :return a rotated list
        """
        # assert img.shape == (1024, 1024)
        return self.__rotation(img)

    def tensor_inverse_rotation(self, img_list):
        """
        img size: [H, W]
        rotation degree: [90 180 270]
        :return a rotated list
        """
        # assert img.shape == (1024, 1024)
        return self.__inverse_rotation(img_list[0], img_list[1], img_list[2])

    def tensor_flip(self, img):
        """
        img size: [H, W]
        :return a flipped list
        """
        # assert img.shape == (1024, 1024)
        return self.__flip(img)

    def tensor_inverse_flip(self, img_list):
        """
        img size: [H, W]
        :return a flipped list
        """
        # assert img.shape == (1024, 1024)
        return self.__inverse_flip(img_list[0], img_list[1])


if __name__ == "__main__":
    a = torch.tensor([[0, 1],[2, 3]]).unsqueeze(0).unsqueeze(0)
    print(a)
    tta = Test_time_agumentation()
    # a = tta.tensor_rotation(a)
    a = tta.tensor_flip(a)
    print(a)
    a = tta.tensor_inverse_flip(a)
    print(a)

同时我将代码release到了https://github.com/Limingxing00/Test-time-augmentation。
代码后续更新将在GitHub上。欢迎大家讨论交流!

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章