學過 CNN 的人都知道,im2col 是非常重要的函數之一,用於將輸入的四維數據轉化成二維數據,方便進行卷積運算。
代碼雖然只有十來行,但要理解透徹還是相當不容易;第一次寫涉及如此高維的數據處理,有點難以適應……
im2col
def im2col(input_data, filter_h, filter_w, stride=1, pad=0):
    """Unroll the sliding windows of a 4-D batch into rows of a 2-D matrix.

    Parameters
    ----------
    input_data : 4-D array laid out as (batch size, channels, height, width)
    filter_h : filter height
    filter_w : filter width
    stride : step between neighbouring windows
    pad : number of zero rows/columns added on each side of height and width

    Returns
    -------
    col : 2-D float array with one row per window, ordered by
        (sample, out_y, out_x); each row holds that window's values ordered
        by (channel, filter_y, filter_x).
    """
    batch, channels, height, width = input_data.shape
    out_h = (height + 2 * pad - filter_h) // stride + 1
    out_w = (width + 2 * pad - filter_w) // stride + 1

    # Zero-pad only the two spatial axes.
    pad_spec = [(0, 0), (0, 0), (pad, pad), (pad, pad)]
    padded = np.pad(input_data, pad_spec, 'constant')

    patches = np.zeros((batch, channels, filter_h, filter_w, out_h, out_w))
    for dy, dx in np.ndindex(filter_h, filter_w):
        # For a fixed offset (dy, dx) inside the filter, a strided slice picks
        # that offset's pixel from every window at once: the window origin
        # moves `stride` pixels per step, for out_h (resp. out_w) steps.
        rows = slice(dy, dy + stride * out_h, stride)
        cols = slice(dx, dx + stride * out_w, stride)
        patches[:, :, dy, dx, :, :] = padded[:, :, rows, cols]

    # (N, C, fh, fw, oh, ow) -> (N, oh, ow, C, fh, fw), then one row per window.
    return patches.transpose(0, 4, 5, 1, 2, 3).reshape(batch * out_h * out_w, -1)
col2im
def col2im(col, input_shape, filter_h, filter_w, stride=1, pad=0):
    """Fold an im2col matrix back into an image batch, averaging overlaps.

    Each patch value is accumulated at the pixel it was taken from, and the
    sum is divided by the number of windows covering that pixel.

    Parameters
    ----------
    col : 2-D array with one row per window (rows ordered by sample, out_y,
        out_x; columns ordered by channel, filter_y, filter_x)
    input_shape : shape of the image batch to rebuild, e.g. (10, 1, 28, 28)
    filter_h : filter height
    filter_w : filter width
    stride : step between neighbouring windows
    pad : zero padding that was applied on each side of height and width

    Returns
    -------
    img : float array of shape ``input_shape``. Pixels that no window covers
        (possible when stride > filter size or the windows do not tile the
        image exactly) are returned as 0.
    """
    N, C, H, W = input_shape
    out_h = (H + 2*pad - filter_h)//stride + 1
    out_w = (W + 2*pad - filter_w)//stride + 1

    # Back to the (N, C, filter_h, filter_w, out_h, out_w) layout.
    col = col.reshape(N, out_h, out_w, C, filter_h, filter_w).transpose(0, 3, 4, 5, 1, 2)

    # The padded size is large enough: the largest row index written below is
    # (filter_h - 1) + stride*(out_h - 1) <= H + 2*pad - 1 (same for columns),
    # so the extra "+ stride - 1" margin seen in some references is not needed.
    padded_h = H + 2*pad
    padded_w = W + 2*pad
    img = np.zeros((N, C, padded_h, padded_w))
    # Coverage count is identical for every sample and channel, so keep it 2-D
    # and let broadcasting apply it.
    cnt = np.zeros((padded_h, padded_w))

    for y in range(filter_h):
        y_max = y + stride*out_h
        for x in range(filter_w):
            x_max = x + stride*out_w
            img[:, :, y:y_max:stride, x:x_max:stride] += col[:, :, y, x, :, :]
            cnt[y:y_max:stride, x:x_max:stride] += 1

    # Uncovered pixels have cnt == 0 (and img == 0); clamp to 1 so they come
    # out as 0 instead of NaN from a 0/0 division.
    np.maximum(cnt, 1, out=cnt)
    res = img / cnt
    return res[:, :, pad:H + pad, pad:W + pad]
demo
import numpy as np
def my_im2col(input_data, fh, fw, stride=1, pad=0):
    """Window-by-window im2col: copy each filter-sized patch into one row.

    Parameters
    ----------
    input_data : 4-D array laid out as (batch size, channels, height, width)
    fh : filter height
    fw : filter width
    stride : step between neighbouring windows
    pad : number of zero rows/columns added on each side of height and width

    Returns
    -------
    col : 2-D float array with one row per window, ordered by
        (sample, out_y, out_x); each row holds that window's values ordered
        by (channel, filter_y, filter_x).
    """
    N, C, H, W = input_data.shape
    out_h = (H + 2 * pad - fh) // stride + 1
    out_w = (W + 2 * pad - fw) // stride + 1
    img = np.pad(input_data, [(0, 0), (0, 0), (pad, pad), (pad, pad)], 'constant')
    col = np.zeros((N, C, out_h, out_w, fh, fw))
    for y in range(out_h):
        # The stride moves the window origin; the window itself is a
        # contiguous fh x fw patch (it must not be sub-sampled by the stride).
        top = y * stride
        for x in range(out_w):
            left = x * stride
            col[:, :, y, x, :, :] = img[:, :, top:top + fh, left:left + fw]
    # (N, C, oh, ow, fh, fw) -> (N, oh, ow, C, fh, fw), then one row per window.
    col = col.transpose(0, 2, 3, 1, 4, 5).reshape(N * out_h * out_w, -1)
    return col
def im2col(input_data, filter_h, filter_w, stride=1, pad=0):
    """Unroll the sliding windows of a 4-D batch into rows of a 2-D matrix.

    Parameters
    ----------
    input_data : 4-D array laid out as (batch size, channels, height, width)
    filter_h : filter height
    filter_w : filter width
    stride : step between neighbouring windows
    pad : number of zero rows/columns added on each side of height and width

    Returns
    -------
    col : 2-D float array with one row per window, ordered by
        (sample, out_y, out_x); each row holds that window's values ordered
        by (channel, filter_y, filter_x).
    """
    batch, channels, height, width = input_data.shape
    out_h = (height + 2 * pad - filter_h) // stride + 1
    out_w = (width + 2 * pad - filter_w) // stride + 1

    # Zero-pad only the two spatial axes.
    pad_spec = [(0, 0), (0, 0), (pad, pad), (pad, pad)]
    padded = np.pad(input_data, pad_spec, 'constant')

    patches = np.zeros((batch, channels, filter_h, filter_w, out_h, out_w))
    for dy, dx in np.ndindex(filter_h, filter_w):
        # For a fixed offset (dy, dx) inside the filter, a strided slice picks
        # that offset's pixel from every window at once: the window origin
        # moves `stride` pixels per step, for out_h (resp. out_w) steps.
        rows = slice(dy, dy + stride * out_h, stride)
        cols = slice(dx, dx + stride * out_w, stride)
        patches[:, :, dy, dx, :, :] = padded[:, :, rows, cols]

    # (N, C, fh, fw, oh, ow) -> (N, oh, ow, C, fh, fw), then one row per window.
    return patches.transpose(0, 4, 5, 1, 2, 3).reshape(batch * out_h * out_w, -1)
def col2im(col, input_shape, filter_h, filter_w, stride=1, pad=0):
    """Fold an im2col matrix back into an image batch, averaging overlaps.

    Each patch value is accumulated at the pixel it was taken from, and the
    sum is divided by the number of windows covering that pixel.

    Parameters
    ----------
    col : 2-D array with one row per window (rows ordered by sample, out_y,
        out_x; columns ordered by channel, filter_y, filter_x)
    input_shape : shape of the image batch to rebuild, e.g. (10, 1, 28, 28)
    filter_h : filter height
    filter_w : filter width
    stride : step between neighbouring windows
    pad : zero padding that was applied on each side of height and width

    Returns
    -------
    img : float array of shape ``input_shape``. Pixels that no window covers
        (possible when stride > filter size or the windows do not tile the
        image exactly) are returned as 0.
    """
    N, C, H, W = input_shape
    out_h = (H + 2*pad - filter_h)//stride + 1
    out_w = (W + 2*pad - filter_w)//stride + 1

    # Back to the (N, C, filter_h, filter_w, out_h, out_w) layout.
    col = col.reshape(N, out_h, out_w, C, filter_h, filter_w).transpose(0, 3, 4, 5, 1, 2)

    # The padded size is large enough: the largest row index written below is
    # (filter_h - 1) + stride*(out_h - 1) <= H + 2*pad - 1 (same for columns),
    # so the extra "+ stride - 1" margin seen in some references is not needed.
    padded_h = H + 2*pad
    padded_w = W + 2*pad
    img = np.zeros((N, C, padded_h, padded_w))
    # Coverage count is identical for every sample and channel, so keep it 2-D
    # and let broadcasting apply it.
    cnt = np.zeros((padded_h, padded_w))

    for y in range(filter_h):
        y_max = y + stride*out_h
        for x in range(filter_w):
            x_max = x + stride*out_w
            img[:, :, y:y_max:stride, x:x_max:stride] += col[:, :, y, x, :, :]
            cnt[y:y_max:stride, x:x_max:stride] += 1

    # Uncovered pixels have cnt == 0 (and img == 0); clamp to 1 so they come
    # out as 0 instead of NaN from a 0/0 division.
    np.maximum(cnt, 1, out=cnt)
    res = img / cnt
    return res[:, :, pad:H + pad, pad:W + pad]
# Demo: push a 1x1x4x4 ramp (values 1..16) through im2col and back.
data = np.arange(1, 17).reshape(1, 1, 4, 4)
print(data)

# A 4x4 filter with stride 2 on the pad-1 (6x6) image yields a 2x2 grid of
# windows, so `col` has 4 rows of 16 values each.
col = im2col(input_data=data, filter_h=4, filter_w=4, stride=2, pad=1)
print(col)

# col2im averages the overlapping window values, so the original pixel
# values are printed back (as floats).
img = col2im(col=col, input_shape=data.shape, filter_h=4, filter_w=4, stride=2, pad=1)
print(img)