MNIST 數據集規則化
作爲大三的我,前一段時間搞了手寫數字識別,什麼支持向量機啊,人工神經網絡啊,knn啊,都玩過了,但僅僅是停留在人家公開的訓練數據集上,而拿來測試用的圖片也是人家的,比如mnist上的圖片是這樣的。
而我們拍的照片卻是這樣的
這樣就帶來了一個問題,測試的結果就不對了,用這樣的圖片去測試訓練好的模型,可能有時候精度連10%都達不到。
這樣,有必要將MNIST的生成過程學習一遍。
MNIST數據集
MNIST數據集是一個手寫數字的集合,包含了60000個訓練集和10000個測試集。每一個數字都是20x20的,包含在28x28的圖片裏面。這對於我們預處理是非常重要的。
預處理
#我們先預處理圖片
i = 0
for no in [8,0,4,3]:
gray = cv2.imread("own_"+str(no)+".png", cv2.IMREAD_GRAYSCALE)
gray = cv2.resize(255-gray,(28,28))
cv2.imwrite("preprocessing/proImage_"+str(no)+".png",gray)
得到的圖片是這樣的:
這相對於原始圖片已經好多了,但我們還需要改進這。
我們添加下面的代碼:
(thresh, gray) = cv2.threshold(gray, 128, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)
這段代碼的含義,我不是學習圖像處理的,於是查了一點資料,這段代碼的含義就是圖像閾值化處理,怎麼說呢,就是圖像中有超過閾值的一律歸爲255,比如這段代碼中的閾值爲128,只要圖片中有像素超過128的就歸爲255,這樣處理出來的效果爲:
到這裏我們已經處理了這個圖片,但有一個最基本的問題就是,數字有可能不在中間,因爲MNIST數據集所有的數字都在中間。
首先,我們想將調整20x20的數字位置。
由於對圖像處理不太熟悉,這裏先貼上代碼:
import numpy as np
import cv2
import math
from scipy import ndimage
gray = cv2.imread("own_4.png", cv2.IMREAD_GRAYSCALE)
(thresh, gray) = cv2.threshold(gray, 128, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)
gray = cv2.resize(255-gray,(28,28))
while np.sum(gray[0]) == 0:
gray = gray[1:]
while np.sum(gray[:,0]) == 0:
gray = np.delete(gray,0,1)
while np.sum(gray[-1]) == 0:
gray = gray[:-1]
while np.sum(gray[:,-1]) == 0:
gray = np.delete(gray,-1,1)
rows,cols = gray.shape
if rows>cols:
factor = 20.0/rows
rows = 20
cols = int(round(cols*factor))
gray = cv2.resize(gray,(cols,rows))
else:
factor = 20.0/cols
cols = 20
rows = int(round(rows*factor))
gray = cv2.resize(gray,(cols,rows))
colsPadding = (int(math.ceil((28-cols)/2.0)),int(math.floor((28-cols)/2.0)))
rowsPadding = (int(math.ceil((28-rows)/2.0)),int(math.floor((28-rows)/2.0)))
gray = np.lib.pad(gray,(rowsPadding,colsPadding),'constant')
def getBestShift(img):
cy,cx = ndimage.measurements.center_of_mass(img)
rows,cols = img.shape
shiftx = np.round(cols/2.0-cx).astype(int)
shifty = np.round(rows/2.0-cy).astype(int)
return shiftx,shifty
def shift(img,sx,sy):
rows,cols = img.shape
M = np.float32([[1,0,sx],[0,1,sy]])
shifted = cv2.warpAffine(img,M,(cols,rows))
return shifted
shiftx,shifty = getBestShift(gray)
shifted = shift(gray,shiftx,shifty)
gray = shifted
cv2.imwrite("test.png",gray)
這段代碼的意思就是,讓不規則的圖片變爲規則的圖片如MNIST,比如這樣的:
處理之後是這樣的:
這樣,就完成了圖片的處理。
通過對圖片的處理化,我們可以將其應用在實際應用當中。
參考資料:http://openmachin.es/blog/tensorflow-mnist
接下來學習多數字識別,如這樣的:
將其中的每一個數字都識別出來。^_^