TF code: https://github.com/kevinzakka/spatial-transformer-network
一、相關背景
如果網絡能夠對經過平移、旋轉、縮放及裁剪等操作的圖片得到與未經變換前相同的檢測結果,我們就說這個網絡具有空間變換不變性(將平移、旋轉、縮放及裁剪不變性統稱爲空間不變性)。具有空間變換不變性的網絡能夠得到更精確地分類結果。傳統CNN網絡的池化層具有平移不變性(網絡在平移小於池化矩陣的範圍時具有平移不變性。所以只有平移小於這個範圍,才能保證平移不變性。),但是CNN網絡對於大尺度的空間變換並不具備不變性。Spatial Transformer Networks提出的空間網絡變換層,具有平移不變性、旋轉不變性及縮放不變性等強大的性能。這個網絡可以加在現有的卷積網絡中,提高分類的準確性。
如下圖所示:輸入手寫字體,我們感興趣的是黃色框中的包含數字的區域,那麼在訓練的過程中,學習到的空間變換網絡會自動提取黃色框中的局部數據特徵,並對框內的數據進行空間變換,得到輸出output。綜上所述,空間變換網絡主要有如下三個作用:
- 可以將輸入轉換爲下一層期望的形式
- 可以在訓練的過程中自動選擇感興趣的區域特徵
- 可以實現對各種形變的數據進行空間變換
圖1.空間變換網絡作用示意圖
二、相關理論
在理解STN之前,先簡單瞭解一下基本的仿射變換、雙線性插值。
- 仿射變換(Affine transformation)
下面的所有變換假設都是針對一幅圖像,即一個三維數組(H*W*C),這裏爲簡單起見,假設圖像都是單通道(C=1)的。首先說明一下待會要用到的符號:
- (x,y): 原圖像中某一點A的位置
- (x′,y′): 變換後圖像中A點對應的位置
平移(translation)
若將原圖像沿x和y方向分別平移 和 ,即:
寫成矩陣形式如下:
縮放(Scaling)
假設將圖像分別沿x和y方向分別縮放p倍和q倍,且p>0,q>0,即:
寫成矩陣形式如下:
旋轉(Rotation)
圖2.旋轉變換示意圖
如上圖所示,點A旋轉θ角到點B,由B點可得
由A點可得:
整理可得
寫成矩陣形式如下:
剪切(Shear)
剪切變換指的是類似於四邊形不穩定性那種性質,方形變平行四邊形。任意一邊都可以被拉長,以一定比例的x補償y,也以一定比例的y補償x。
仿射變換(Affine transformation)
其實上面幾種常見變換都可以用同一種變換來表示,就是仿射變換,它有更一般的形式,如下:
a,b,c,d,e,f取不同的值就可以表示上述不同的變換。當6個參數取其上述變換以外的值時,爲一般的仿射變換,效果相當於從不同的位置看同一個目標。
2.雙線性插值(Bilinear Interpolation)
在對圖像進行仿射變換時,會出現一個問題,當原圖像中某一點的座標映射到變換後圖像時,座標可能會出現小數,而我們知道,圖像上某一像素點的位置座標只能是整數,那該怎麼辦?這時候雙線性插值就起作用了。在介紹雙線性插值之前,先講一下線性插值的計算方法:已知點 (x0, y0) 與 (x1, y1),要計算 [x0, x1] 區間內某一位置 x 在直線上的y值,可以採用兩點式寫出直線方程並求得y的取值如下:
雙線性插值的基本思想是通過某一點周圍四個點的灰度值來估計出該點的灰度值,如圖3所示.
圖3.雙線性插值示意圖
已知Q11、Q12、Q21、Q22四點的座標,要求點P的座標。分成兩步,首先在 x 方向進行線性插值,得到:
然後在 y 方向進行線性插值,得到:
由於圖像雙線性插值只會用相鄰的4個點,因此上述公式的分母都是1。整合上述公式有:
三、算法概述
STN網絡包括三部分:
- Localisation Network-局部網絡
- Parameterised Sampling Grid-參數化網格採樣
- Differentiable Image Sampling-差分圖像採樣
- Localisation Network-局部網絡
輸入:特徵圖
輸出:變換矩陣 ,用於下一步計算( 輸出規模視具體的變換。以仿射變換爲例, 是一個[2,3]大小的6維參數)
注: 被初始化爲恆等變換矩陣,通過損失函數不斷更正的參數,最終得到期望的仿射變換矩陣。得到輸出特徵圖後最重要的是得到輸出特徵圖每個位置的像素值。(圖像對於計算機來說就是一個0-255的像素值組成的矩陣,圖像經過空間變換後每個點的像素值肯定會發生變化,下面就介紹如何確定變換後的特徵圖每個位置的像素值)
2. Parameterised Sampling Grid-參數化網格採樣
此步驟的目地是爲了得到輸出特徵圖的座標點對應的輸入特徵圖的座標點的位置。計算方式如下:
式中s代表輸入特徵圖像座標點,t代表輸出特徵圖座標點, 是局部網絡的輸出。這裏需要注意的是座標的映射關係是從目標圖片——>輸入圖片。這是因爲輸入圖片與目標圖片座標點均是人爲定義的標準化格點矩陣,x,y的值在-1到1之間,圖片任何一個位置的座標點是固定不變的。這就好比兩個座標完全一樣的圖像,無論用誰乘以仿射變換矩陣,都可以得到經過仿射變換後的圖像與原座標點的映射關係。也就是說這裏即使把座標的映射關係變爲輸入圖片——>目標圖片得到的也是一樣的映射關係。至於爲什麼要使用前者來求解這種映射關係,個人理解的是目標圖片是我們期望的輸出,我們通常以輸出爲參考,依次獲得目標圖片在每個座標點的像素值。比如目標圖片座標點(0,0)對應輸入圖片座標點(3,1),我們就先取出輸入圖片座標點(3,1)處的像素值,這樣依次獲得目標圖片在每個座標點的像素值。通過上面的解釋相信你們也能理解爲什麼沒有使用仿射變換的逆矩陣。
通過這一步,我們已經得到變換後的輸出特徵圖每個位置的座標在輸入特徵圖上的對應座標點。下面我們就可以直接提取出輸入特徵圖的每個位置的像素值(tensorflow有專門的函數可以得到指定位置的像素值)。在提取像素值之前,我們應該注意到一點:目標圖片的座標點對應的輸入圖片的座標點不一定是整數座標點(例如目標圖片座標點(0,1)對應輸入圖片座標點(3.2,1.3)),而僅僅整數座標才能提取像素值,所以需要利用插值的方式來計算出對應該點的灰度值(像素值)。可以看出,步驟一爲步驟二提供了仿射變換的矩陣,步驟二爲步驟三提供了輸出特徵圖的座標點對應的輸入特徵圖的座標點的位置,步驟三只需要提取這個對應的座標點的像素值(非整數座標需要使用雙向性插值提取像素值)就能最終得到輸出特徵圖V。
左圖爲輸出特徵圖 右圖爲輸入特徵圖
3.Differentiable Image Sampling-差分圖像採樣
這一步完成的任務就是利用期望的插值方式來計算出對應點的灰度值。這裏以雙向性插值爲例講解,論文中給出了雙向性插值的計算公式如下:
爲輸出特徵圖上第c個通道某一點的灰度值, 爲輸入特徵圖上第c個通道點(n,m)的灰度。當或者大於1時,對應的max()項將取0,也就是說,只有 周圍4個點的灰度值決定目標像素點的灰度。並且當 和 越小,影響越大(即離點 (n,m)越近),權重越大,這和我們上面介紹雙線性插值的結論是一致的。其實,這個式子等價於下式:
四、總結及代碼實現-代碼下載
1.Spatial Transformer Networks代碼實現
def transformer(U, theta, out_size, name='SpatialTransformer', **kwargs):
print('begin-transformer')
def _repeat(x, n_repeats):
with tf.variable_scope('_repeat'):
rep = tf.transpose(
tf.expand_dims(tf.ones(shape=tf.stack([n_repeats, ])), 1), [1, 0])
rep = tf.cast(rep, 'int32')
x = tf.matmul(tf.reshape(x, (-1, 1)), rep)
return tf.reshape(x, [-1])
def _interpolate(im, x, y, out_size):
with tf.variable_scope('_interpolate'):
# constants
num_batch = tf.shape(im)[0]
height = tf.shape(im)[1]
width = tf.shape(im)[2]
channels = tf.shape(im)[3]
x = tf.cast(x, 'float32')
y = tf.cast(y, 'float32')
height_f = tf.cast(height, 'float32')
width_f = tf.cast(width, 'float32')
out_height = out_size[0]
out_width = out_size[1]
zero = tf.zeros([], dtype='int32')
max_y = tf.cast(tf.shape(im)[1] - 1, 'int32')
max_x = tf.cast(tf.shape(im)[2] - 1, 'int32')
# scale indices from [-1, 1] to [0, width/height]
x = (x + 1.0) * (width_f) / 2.0
y = (y + 1.0) * (height_f) / 2.0
# do sampling
x0 = tf.cast(tf.floor(x), 'int32')
x1 = x0 + 1
y0 = tf.cast(tf.floor(y), 'int32')
y1 = y0 + 1
x0 = tf.clip_by_value(x0, zero, max_x)
x1 = tf.clip_by_value(x1, zero, max_x)
y0 = tf.clip_by_value(y0, zero, max_y)
y1 = tf.clip_by_value(y1, zero, max_y)
dim2 = width
dim1 = width * height
base = _repeat(tf.range(num_batch) * dim1, out_height * out_width)
base_y0 = base + y0 * dim2
base_y1 = base + y1 * dim2
idx_a = base_y0 + x0
idx_b = base_y1 + x0
idx_c = base_y0 + x1
idx_d = base_y1 + x1
# use indices to lookup pixels in the flat image and restore
# channels dim
im_flat = tf.reshape(im, tf.stack([-1, channels]))
im_flat = tf.cast(im_flat, 'float32')
Ia = tf.gather(im_flat, idx_a)
Ib = tf.gather(im_flat, idx_b)
Ic = tf.gather(im_flat, idx_c)
Id = tf.gather(im_flat, idx_d)
# and finally calculate interpolated values
x0_f = tf.cast(x0, 'float32')
x1_f = tf.cast(x1, 'float32')
y0_f = tf.cast(y0, 'float32')
y1_f = tf.cast(y1, 'float32')
wa = tf.expand_dims(((x1_f - x) * (y1_f - y)), 1)
wb = tf.expand_dims(((x1_f - x) * (y - y0_f)), 1)
wc = tf.expand_dims(((x - x0_f) * (y1_f - y)), 1)
wd = tf.expand_dims(((x - x0_f) * (y - y0_f)), 1)
output = tf.add_n([wa * Ia, wb * Ib, wc * Ic, wd * Id])
return output
def _meshgrid(height, width):
print('begin--meshgrid')
with tf.variable_scope('_meshgrid'):
# This should be equivalent to:
# x_t, y_t = np.meshgrid(np.linspace(-1, 1, width),
# np.linspace(-1, 1, height))
# ones = np.ones(np.prod(x_t.shape))
# grid = np.vstack([x_t.flatten(), y_t.flatten(), ones])
x_t = tf.matmul(tf.ones(shape=tf.stack([height, 1])),
tf.transpose(tf.expand_dims(tf.linspace(-1.0, 1.0, width), 1), [1, 0]))
print('meshgrid_x_t_ok')
y_t = tf.matmul(tf.expand_dims(tf.linspace(-1.0, 1.0, height), 1),
tf.ones(shape=tf.stack([1, width])))
print('meshgrid_y_t_ok')
x_t_flat = tf.reshape(x_t, (1, -1))
y_t_flat = tf.reshape(y_t, (1, -1))
print('meshgrid_flat_t_ok')
ones = tf.ones_like(x_t_flat)
print('meshgrid_ones_ok')
print(x_t_flat)
print(y_t_flat)
print(ones)
grid = tf.concat([x_t_flat, y_t_flat, ones], 0)
print('over_meshgrid')
return grid
def _transform(theta, input_dim, out_size):
print('_transform')
with tf.variable_scope('_transform'):
num_batch = tf.shape(input_dim)[0]
height = tf.shape(input_dim)[1]
width = tf.shape(input_dim)[2]
num_channels = tf.shape(input_dim)[3]
theta = tf.reshape(theta, (-1, 2, 3))
theta = tf.cast(theta, 'float32')
# grid of (x_t, y_t, 1), eq (1) in ref [1]
height_f = tf.cast(height, 'float32')
width_f = tf.cast(width, 'float32')
out_height = out_size[0]
out_width = out_size[1]
grid = _meshgrid(out_height, out_width)
grid = tf.expand_dims(grid, 0)
grid = tf.reshape(grid, [-1])
grid = tf.tile(grid, tf.stack([num_batch]))
grid = tf.reshape(grid, tf.stack([num_batch, 3, -1]))
# tf.batch_matrix_diag
# Transform A x (x_t, y_t, 1)^T -> (x_s, y_s)
print('begin--batch--matmul')
T_g = tf.matmul(theta, grid)
x_s = tf.slice(T_g, [0, 0, 0], [-1, 1, -1])
y_s = tf.slice(T_g, [0, 1, 0], [-1, 1, -1])
x_s_flat = tf.reshape(x_s, [-1])
y_s_flat = tf.reshape(y_s, [-1])
input_transformed = _interpolate(
input_dim, x_s_flat, y_s_flat,
out_size)
output = tf.reshape(
input_transformed, tf.stack([num_batch, out_height, out_width, num_channels]))
print('over_transformer')
return output
with tf.variable_scope(name):
output = _transform(theta, U, out_size)
return output
def batch_transformer(U, thetas, out_size, name='BatchSpatialTransformer'):
with tf.variable_scope(name):
num_batch, num_transforms = map(int, thetas.get_shape().as_list()[:2])
indices = [[i] * num_transforms for i in xrange(num_batch)]
input_repeated = tf.gather(U, tf.reshape(indices, [-1]))
return transformer(input_repeated, thetas, out_size)
2.STN網絡測試代碼
from scipy import ndimage
import tensorflow as tf
from STN_tf_01 import transformer
import numpy as np
import matplotlib.pyplot as plt
import cv2
im = ndimage.imread('C:\\Users\julie\Desktop\cat.jpg')#改爲你自己要測試的圖片路徑
im = im / 255.
# im=tf.reshape(im, [1,1200,1600,3])
im = im.reshape(1, 1200, 1600, 3)
im = im.astype('float32')
print('img-over')
out_size = (600, 800)
batch = np.append(im, im, axis=0)
batch = np.append(batch, im, axis=0)
num_batch = 3
x = tf.placeholder(tf.float32, [None, 1200, 1600, 3])
x = tf.cast(batch, 'float32')
print('begin---')
with tf.variable_scope('spatial_transformer_0'):
n_fc = 6
w_fc1 = tf.Variable(tf.Variable(tf.zeros([1200 * 1600 * 3, n_fc]), name='W_fc1'))
initial = np.array([[0.5, 0, 0], [0, 0.5, 0]])
initial = initial.astype('float32')
initial = initial.flatten()
b_fc1 = tf.Variable(initial_value=initial, name='b_fc1')
h_fc1 = tf.matmul(tf.zeros([num_batch, 1200 * 1600 * 3]), w_fc1) + b_fc1
print(x, h_fc1, out_size)
h_trans = transformer(x, h_fc1, out_size)
sess = tf.Session()
sess.run(tf.global_variables_initializer())
y = sess.run(h_trans, feed_dict={x: batch})
plt.imshow(y[0])
plt.show()
效果如下:
輸入圖片
經過STN網絡的圖片