PyTorch:學習torch.nn.functional.grid_sample

背景

最近在學習SfMLearner,其中一個非常重要的部分是Differentiable depth image-based rendering,翻譯過來就是基於深度的可微圖像渲染。這看起來好像很高大上,但是換句話說其實就是要根據深度,在當前影像上生成另一個視角的影像。不多說這個了,這其中一個比較重要的部分就是,雙線性採樣,論文裏圖示如下,ItI_t是目標影像(即前文說的另一個視角的影像),ItI_t是原始影像,ItI_t上一個整數座標ptp_t,根據深度投影到IsI_s上後得到一個浮點型座標psp_s;此時就要進行雙線性採樣,利用IsI_s上的四個點採樣出psp_s的值。
在這裏插入圖片描述
說到這裏,我基本說清楚了SfMLearner這個Rendering的原理,但是立馬就有一個問題出現在我的腦海裏,即這玩意在PyTorch裏應該怎麼實現。在查閱了ClementPinard的代碼SfmLearner-Pytorch後,發現對應的工具是porch.nn.functional.grid_sample

那問題就來了,這玩意到底怎麼用呢?後邊,我將配合代碼瞅瞅這玩意到底咋回事。當然,我的結論不一定完全對,但至少給一些表面的感覺吧。


參考鏈接

  1. https://pytorch.org/docs/master/nn.functional.html#grid-sample
  2. https://github.com/ClementPinard/SfmLearner-Pytorch/blob/master/inverse_warp.py
  3. Unsupervised Learning of Depth and Ego-Motion from Video

代碼及說明

文檔

搞事之前,先簡單看一下文檔,如下:

torch.nn.functional.grid_sample(
input, 
grid, 
mode='bilinear', 
padding_mode='zeros', 
align_corners=None)
  1. 其中input的輸入格式是(N,C,Hin,Win)(N,C,H_{in},W_{in});其中NN對應的是Batch_Size,CC是通道數量;HinH_{in}WinW_{in}對應的是影像的高寬;這裏說明一下,還有一個5-D的輸入,我就不討論了。
  2. grid的輸入格式是(N,Hout,Wout,2)(N,H_{out},W_{out},2);這個NN對應的也是Batch_SizeHoutH_{out}WoutW_{out}分別對應的是grid的形狀,當然,輸出的形狀和這個也是對應的;最後就是這個2,表示的是grid上一個點的座標,分別是x座標和y座標,但注意座標值得範圍是歸一化後的[-1,+1]
  3. mode值得是採樣的方式,這個有nearestbilinearnearest就是最鄰近採樣,bilinear是雙線性插值。
  4. padding_mode 指的是邊緣的處理模式,包括zerosborderreflectionzeros指的是邊緣補充部分爲0border指的是邊緣補充部分直接複製邊緣區域,reflection指的是邊緣補充部分爲根據邊緣的鏡像,舉個例子ABC|CBA,,其中|是邊緣,CBA是原始圖像。
  5. align_corners,這個可以說是讓我腦袋最大的一個參數,看了半天都沒搞懂啥意思;經過多次看文檔和寫代碼測試,我終於有點明白了,這裏簡單解釋一下;當align_corners=True時,座標歸一化範圍是圖像四個角點的中心;當align_corners=False時,座標歸一化範圍是圖像四個角點靠外的角點;爲了更好的說明這個情況,我畫了一個大小爲3×33×3影像進行說明,如下,其中每一個方格代表一個像素,並且像素座標在方格中央;這個圖已經很清楚了吧,如果還不清楚,後邊還有代碼測試。
    在這裏插入圖片描述

代碼

爲了驗證之前的結論,以下有一些代碼進行測試。

  1. 生成一個3*3的輸入
test = torch.rand(1,1,3,3)
test[0][0][0][0]=1 
test[0][0][0][1]=2 
test[0][0][0][2]=3
test[0][0][1][0]=4 
test[0][0][1][1]=5 
test[0][0][1][2]=6
test[0][0][2][0]=7 
test[0][0][2][1]=8 
test[0][0][2][2]=9
print(test)

輸入

tensor([[[[1., 2., 3.],
          [4., 5., 6.],
          [7., 8., 9.]]]])
  1. 定義採樣格網,這裏爲了簡單,只有一個點
sample_one = torch.zeros(1,1,1,2)
sample_one [0][0][0][0] = -1  # x
sample_one [0][0][0][1] = -1  # y

sample_two = torch.zeros(1,1,1,2)
sample_two [0][0][0][0] = -2/3  # x
sample_two [0][0][0][1] = -2/3  # y

sample_thr = torch.zeros(1,1,1,2)
sample_thr [0][0][0][0] = -0.5  # x
sample_thr [0][0][0][1] = -0.5  # y
  1. 採樣
result_one = torch.nn.functional.grid_sample(test,sample,mode='bilinear',padding_mode="zeros",align_corners=True)
print(result_one )

result_two= torch.nn.functional.grid_sample(test,sample_two ,mode='bilinear',padding_mode="zeros",align_corners=False)
print(result_two)

result_thr= torch.nn.functional.grid_sample(test,sample_thr ,mode='bilinear',padding_mode="zeros",align_corners=True)
print(result_thr)

輸出

# result_one 
# 這個很好理解,當`align_corners=True`時,$(-1,-1)$對應的就是`test`的左上角點
tensor([[[[1.]]]]) 

# result_two
# 這個其實很奇葩,當`align_corners=False`時,$(-2/3,-2/3)$對應的纔是`test`的左上角點
tensor([[[[1.]]]])

# result_one 
# 當`align_corners=True`時,如果懂雙線性插值的話,直接算就可以得到1*0.5+2*0.5+4*0.5+5*0.5=3
tensor([[[[3.]]]])

總結

關於torch.nn.functional.grid_sample學習內容就如上了,後續繼續看warp的代碼,加油!!

發佈了152 篇原創文章 · 獲贊 161 · 訪問量 44萬+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章