PyTorch(1.3.0+):學習torch.nn.functional.grid_sample

背景

最近在學習SfMLearner,其中一個非常重要的部分是Differentiable depth image-based rendering,翻譯過來就是基於深度的可微圖像渲染。這看起來好像很高大上,但是換句話說其實就是要根據深度,在當前影像上生成另一個視角的影像。不多說這個了,這其中一個比較重要的部分就是,雙線性採樣,論文裏圖示如下,ItI_t是目標影像(即前文說的另一個視角的影像),ItI_t是原始影像,ItI_t上一個整數座標ptp_t,根據深度投影到IsI_s上後得到一個浮點型座標psp_s;此時就要進行雙線性採樣,利用IsI_s上的四個點採樣出psp_s的值。
在這裏插入圖片描述
說到這裏,我基本說清楚了SfMLearner這個Rendering的原理,但是立馬就有一個問題出現在我的腦海裏,即這玩意在PyTorch裏應該怎麼實現。在查閱了ClementPinard的代碼SfmLearner-Pytorch後,發現對應的工具是porch.nn.functional.grid_sample

那問題就來了,這玩意到底怎麼用呢?後邊,我將配合代碼瞅瞅這玩意到底咋回事。當然,我的結論不一定完全對,但至少給一些表面的感覺吧。


參考鏈接

  1. https://pytorch.org/docs/master/nn.functional.html#grid-sample
  2. https://github.com/ClementPinard/SfmLearner-Pytorch/blob/master/inverse_warp.py
  3. Unsupervised Learning of Depth and Ego-Motion from Video

代碼及說明

文檔

搞事之前,先簡單看一下文檔,如下:

torch.nn.functional.grid_sample(
input, 
grid, 
mode='bilinear', 
padding_mode='zeros', 
align_corners=None)
  1. 其中input的輸入格式是(N,C,Hin,Win)(N,C,H_{in},W_{in});其中NN對應的是Batch_Size,CC是通道數量;HinH_{in}WinW_{in}對應的是影像的高寬;這裏說明一下,還有一個5-D的輸入,我就不討論了。
  2. grid的輸入格式是(N,Hout,Wout,2)(N,H_{out},W_{out},2);這個NN對應的也是Batch_SizeHoutH_{out}WoutW_{out}分別對應的是grid的形狀,當然,輸出的形狀和這個也是對應的;最後就是這個2,表示的是grid上一個點的座標,分別是x座標和y座標,但注意座標值得範圍是歸一化後的[-1,+1]
  3. mode值得是採樣的方式,這個有nearestbilinearnearest就是最鄰近採樣,bilinear是雙線性插值。
  4. padding_mode 指的是邊緣的處理模式,包括zerosborderreflectionzeros指的是邊緣補充部分爲0border指的是邊緣補充部分直接複製邊緣區域,reflection指的是邊緣補充部分爲根據邊緣的鏡像,舉個例子ABC|CBA,,其中|是邊緣,CBA是原始圖像。
  5. align_corners,這個可以說是讓我腦袋最大的一個參數,看了半天都沒搞懂啥意思;經過多次看文檔和寫代碼測試,我終於有點明白了,這裏簡單解釋一下;當align_corners=True時,座標歸一化範圍是圖像四個角點的中心;當align_corners=False時,座標歸一化範圍是圖像四個角點靠外的角點;爲了更好的說明這個情況,我畫了一個大小爲3×33×3影像進行說明,如下,其中每一個方格代表一個像素,並且像素座標在方格中央;這個圖已經很清楚了吧,如果還不清楚,後邊還有代碼測試。
    在這裏插入圖片描述

代碼

爲了驗證之前的結論,以下有一些代碼進行測試。

  1. 生成一個3*3的輸入
test = torch.rand(1,1,3,3)
test[0][0][0][0]=1 
test[0][0][0][1]=2 
test[0][0][0][2]=3
test[0][0][1][0]=4 
test[0][0][1][1]=5 
test[0][0][1][2]=6
test[0][0][2][0]=7 
test[0][0][2][1]=8 
test[0][0][2][2]=9
print(test)

輸入

tensor([[[[1., 2., 3.],
          [4., 5., 6.],
          [7., 8., 9.]]]])
  1. 定義採樣格網,這裏爲了簡單,只有一個點
sample_one = torch.zeros(1,1,1,2)
sample_one [0][0][0][0] = -1  # x
sample_one [0][0][0][1] = -1  # y

sample_two = torch.zeros(1,1,1,2)
sample_two [0][0][0][0] = -2/3  # x
sample_two [0][0][0][1] = -2/3  # y

sample_thr = torch.zeros(1,1,1,2)
sample_thr [0][0][0][0] = -0.5  # x
sample_thr [0][0][0][1] = -0.5  # y
  1. 採樣
result_one = torch.nn.functional.grid_sample(test,sample,mode='bilinear',padding_mode="zeros",align_corners=True)
print(result_one )

result_two= torch.nn.functional.grid_sample(test,sample_two ,mode='bilinear',padding_mode="zeros",align_corners=False)
print(result_two)

result_thr= torch.nn.functional.grid_sample(test,sample_thr ,mode='bilinear',padding_mode="zeros",align_corners=True)
print(result_thr)

輸出

# result_one 
# 這個很好理解,當`align_corners=True`時,$(-1,-1)$對應的就是`test`的左上角點
tensor([[[[1.]]]]) 

# result_two
# 這個其實很奇葩,當`align_corners=False`時,$(-2/3,-2/3)$對應的纔是`test`的左上角點
tensor([[[[1.]]]])

# result_one 
# 當`align_corners=True`時,如果懂雙線性插值的話,直接算就可以得到1*0.5+2*0.5+4*0.5+5*0.5=3
tensor([[[[3.]]]])

補充

補充兩個代碼,分別是在二維影像上進行一次採樣和多次採樣

class SampleFeatureSingle(nn.Module):
    def __init__(self):
        super(SampleFeatureSingle, self).__init__()
        
    def forward(self,feature,x_move,y_move):
        
        b,c,h,w = feature.shape
                
        x_range = torch.arange(0, w).view(1, 1, w).expand(b,h,w).float() + x_move
        y_range = torch.arange(0, h).view(1, h, 1).expand(b,h,w).float() + y_move

        x_range = 2.*x_range/(w-1) - 1
        y_range = 2.*y_range/(h-1) - 1
        
        grid = torch.stack((x_range,y_range), dim=3)
        
        sample = F.grid_sample(feature,grid,mode='bilinear',padding_mode="zeros",align_corners=True)
        
        return sample

class SampleFeatureMulti(nn.Module):
    def __init__(self):
        super(SampleFeatureMulti, self).__init__()
        
    def forward(self,feature,x_move,y_move):
        
        b,c,t,h,w = feature.shape
        bd,td,hd,wd = x_move.shape
        
        x_range = torch.arange(0, w).view(1, 1, w).expand(b,1,h,w).expand(b,td,h,w).float().cuda() + x_move
        y_range = torch.arange(0, h).view(1, h, 1).expand(b,1,h,w).expand(b,td,h,w).float().cuda() + y_move
        z_range = torch.arange(0, td).view(1, td, 1, 1).expand(b,td,h,w).float().cuda()
        
        x_range = 2.*x_range/(w-1) - 1
        y_range = 2.*y_range/(h-1) - 1
        z_range = 2.*z_range/(td-1) - 1
        
        grid = torch.stack((x_range,y_range,z_range), dim=4)
        
        sample = F.grid_sample(feature,grid,mode='bilinear',padding_mode="zeros",align_corners=True)
        
        return sample

總結

關於torch.nn.functional.grid_sample學習內容就如上了,後續繼續看warp的代碼,加油!!

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章