Pytorch的若干非奇技淫巧

數據處理

from skimage import io
img = io.imread(filename, as_gray=True)

注意這裏as_gray=True返回的img其實是除以255的,所以範圍是[0,1]

import numpy as np
import torch

xx, yy = np.meshgrid(np.arange(w), np.arange(h))
yy, xx = torch.meshgrid(torch.arange(w), torch.arange(h))

注意numpytorchmeshgrid不同,返回的座標順序剛好相反

import torch
import torchvision

transform = torchvision.transforms.Compose([
	torchvision.transforms.ToTensor(),
	torchvision.transforms.Normalize(mean=(,), std=(,))

])

ToTensor是把圖片轉化成浮點型,也就是[0,1]範圍內,而Normalize則是用高斯正則化,tensor = (tensor - mean) / (std + eps),如果mean=0.5, std=0.5的話,則相當於tensor = tensor * 2 - 1,將數據歸一化到[-1,1]

import os
import os.path as osp
import torchvision
from skimage import io

filename = osp.join(self.hparams.log_dir, "%s_%s_%s.png" % \
 (self.global_step, batch_idx, random_idx))
imgs = torchvision.utils.make_grid(imgs, \ 
							nrow=V).permute(1,2,0).detach().cpu().numpy().astype(np.uint8)
io.imsave(filename, imgs)

注意這裏的nrow指的並不是行數,而是列數

張量操作

import torch

a = torch.randn(3,10)
b = torch.randn(3,10)

c = torch.cat((a, b), dim=0) # [6, 10]
d = torch.stack((a, b), dim=0) # [2, 3, 10]

cat會在dim維上進行拼接,而stack意爲堆疊,會增加新的一維,stack常用於圖片的堆疊

節省顯存

  • 儘量使用inplace,原地操作,不再複製張量
    def inplace_relu(m):
        classname = m.__class__.__name__
        if classname.find('ReLU') != -1:
            m.inplace=True
    
    model.apply(inplace_relu)
  • 不做梯度傳播
    with torch.no_grad():
    	# 一些不需要梯度計算的運算,比如純數學運算
  • 刪除臨時張量,釋放cache顯存
    del r'''一些不再需要的臨時張量'''
    torch.cuda.empty_cache()
  • 使用非確定性算法,尋找cuDNN最優配置,優化框架效率
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True

待更

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章