Dense Deformation(Elastic deformation)的採樣與可視化
對於採樣(sample)而言,分爲前向與後向。前向直觀上更容易理解,但實際應用會有問題,因此,算法中的採樣操作一般爲後向。
採樣
前向採樣:
- 將輸入圖像座標
p1
根據DDF(Dense deformation Feild)
或Affine Matrix
或其他方式進行變換得到座標p2
,輸入圖像p1
座標處的值 映射到輸出圖像p2
處。
僞代碼示例:
#與input_image shape一致的空矩陣
output_image = empty_array()
for row in input_image.height:
for col in input_image.width:
# 或者 dx,dy = Affinematrix*[row,col]
dx,dy = random(),random()
output_image[row+dy,col+dx] = input_image[row,col]
這樣會出現的問題是,output_image中,很多座標上沒有值。而後向採樣就能解決這個問題。
後向採樣:
- 將輸出圖像座標
p1
根據DDF(Dense deformation Feild)
或Affine Matrix的逆
或其他方式進行變換得到座標p2
,輸出圖像p1
座標處的值 根據輸入圖像p2
座標處經過插值算法得到的像素值確定。
僞代碼示例:
#與input_image shape一致的空矩陣
output_image = empty_array()
for row in output_image.height:
for col in output_image.width:
# 或者 dx,dy = Affinematrix_inverse*[row,col]
dx,dy = random(),random()
# 根據插值算法,得到input_image中(row+dy,col+dx)處的像素值
value = interpolate(input_image,row+dy,col+dx)
output_image[row,col] = value
由僞代碼也易看出,後向採樣可以確保輸出圖像中每個座標都有對應的像素值(若row+dy或col+dx超出輸入圖像範圍,value=None,則是變形的結果,而不是輸出圖像這一座標沒有值)
變形場可視化
這一部分是我自己的思考,不一定正確,若有錯誤,請指正。
這兒的變形場是針對DDF的,可視化我認爲有兩種形式,
- 一種是以網格線變形的形式(圖1)。第一種形式的可視化我是通過在圖像上畫線,然後對畫了線的圖像做變形,最後可視化出來。1 結果自然就能夠呈現出網格線變形的效果(不太清楚爲什麼出現了很多虛線,記得幾個月前看過一篇博客說這種原因,以及怎麼處理,現在忘記了)。
- 另一種是用箭頭可視化變形矢量場(圖2)。直接通過
matplotlib
的quiver
接口實現。(但我認爲這兒畫的有問題,因爲我是通過反向採樣實現的變形,例如輸入圖像(1,1)處的值要變形到輸出圖像的(2,2)處,根據直觀感受來說,此處的箭頭應該是位於(1,1)處,指向(2,2)處,但這兒有的數據則是輸出圖像(2,2)處爲(-1,-1),即dx=-1,dy=-1,會表現爲出現在(2,2)處指向(1,1)的箭頭,雖然可以通過取反調轉箭頭方向,然而箭尾仍舊出現在(2,2)處)
變形與可視化實例
環境 python3, jupyter notebook
from scipy.ndimage.interpolation import map_coordinates
from scipy.ndimage.filters import gaussian_filter
import cv2
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
def draw_grid(im, grid_size):
# Draw grid lines
for i in range(0, im.shape[1], grid_size):
cv2.line(im, (i, 0), (i, im.shape[0]), color=(1,))
for j in range(0, im.shape[0], grid_size):
cv2.line(im, (0, j), (im.shape[1], j), color=(1,))
"""
形式一
"""
image = cv2.imread('./r.jpeg',0)
draw_grid(image,40)
shape = image.shape
sigma = 10
alpha = 180
random_state = np.random.RandomState(1)
dx = gaussian_filter((random_state.rand(*shape) * 2 - 1), sigma) * alpha
dy = gaussian_filter((random_state.rand(*shape) * 2 - 1), sigma) * alpha
x,y = np.meshgrid(np.arange(shape[1]), np.arange(shape[0]))
indices = np.reshape(y+dy, (-1, 1)), np.reshape(x+dx, (-1, 1))
# map_coordinates函數的作用:根據indices中指定的座標,在image中進行採樣
#也就相當於反向採樣
image_deform = map_coordinates(image, indices, order=1, mode='reflect').reshape(shape)
plt.imshow(image_deform,cmap='gray')
"""
---------------------------------------------------------------
形式二
"""
image = cv2.imread('./r.jpeg',0)
shape = image.shape
sigma = 10
alpha = 180
random_state = np.random.RandomState(1)
dx = gaussian_filter((random_state.rand(*shape) * 2 - 1), sigma) * alpha
dy = gaussian_filter((random_state.rand(*shape) * 2 - 1), sigma) * alpha
x,y = np.meshgrid(np.arange(shape[1]), np.arange(shape[0]))
# 注意,這兒用np的slice技巧對密集的點進行了下采樣,否則畫出來的箭頭
#完全不能看
plt.quiver(x[::30,::30], y[::30,::30], -dx[::30,::30], -dy[::30,::30],color='red')
indices = np.reshape(y+dy, (-1, 1)), np.reshape(x+dx, (-1, 1))
image_deform = map_coordinates(image, indices, order=2, mode='reflect').reshape(shape)
plt.imshow(image_deform,cmap='gray')
STN(Spatial Transformer Network)
stn作爲一個將transform層化的網絡,有很多值得學習的地方。這兒提及這個網絡,是因爲其採樣機理就是上面提及過的後向採樣(雖然我只看過pytorch的,其餘日後有時間看了再補充),涉及兩個關鍵函數F.affine_grid
, F.grid_sample
2。
用pytorch做小實驗的關鍵部分如下:3
from torch.nn import functional as F
# 這個theta的含義是圖像向右平移20%,向下平移40%
theta = torch.tensor([
[1,0,-0.2],
[0,1,-0.4]
], dtype=torch.float)
# 修改size
N, C, W, H = img_torch.unsqueeze(0).size()
size = torch.Size((N, C, W//2, H//3))
"""
注意這個grid,經過試驗,發現相當於上文出現過的indices,即從輸出映射到輸入,是反向採樣原理。
因爲其x方向的輸出如下
#print(grid.size()) torch.Size([1, 10, 10, 2])
print(grid[0,:,:,0]) #輸出x方向的coordinate
tensor([[-1.2000, -0.9778, -0.7556, -0.5333, -0.3111, -0.0889, 0.1333, 0.3556,
0.5778, 0.8000],
[-1.2000, -0.9778, -0.7556, -0.5333, -0.3111, -0.0889, 0.1333, 0.3556,
0.5778, 0.8000],.......
"""
grid = F.affine_grid(theta.unsqueeze(0), size)
output = F.grid_sample(img_torch.unsqueeze(0), grid)