一、問題描述
今天登龍跟大家分享下使用前饋神經網絡識別 10 種類型手寫字符的方法,不太瞭解神經網絡基礎的同學,可以查看我上一篇文章:從 0 開始機器學習- 深入淺出神經網絡基礎
我們的目標就是用一個已經訓練好的神經網絡來預測下面這 10 類手寫字符 [0 - 9]:
每個字符是一個 20 X 20 = 400
像素的圖片:
OK!我們直接開始,先來看看我們用的神經網絡的架構。
二、神經網絡架構
我們在使用神經網絡之前需要進行參數的訓練,也就是訓練權重矩陣,這篇博客就不詳細展開如何訓練了,後面單獨寫一篇反向 BP 算法的文章介紹。
不管是訓練還是預測,我們都要首先搞清楚使用的神經網絡架構是怎樣的,也就是輸入輸出層有多少節點,有多少個隱藏層,每個隱藏層有多少節點,這些很重要,因爲每層的節點數都作爲權重矩陣的行和列,在預測的時候要使用這些權重矩陣。
我們這個例子使用的的 3 層神經網絡,我來給你詳細分析下這個架構:
- 輸入層(400):輸入特徵爲一個 20 x 20 = 400 像素的字符圖像,所以有 400 個輸入單元,還有一個偏置單元沒算在內
- 隱藏層(25):隱藏層 25 個節點,同樣還有一個偏置單元沒算在內
- 輸出層(10):因爲要分類 10 個數字,所以用 10 個輸出表示類別,哪個輸出 1 表示識別爲哪個數字
結構搞清楚後,我們直接開始預測,下面我帶你解析關鍵的 Python 代碼,完整代碼見文末 Github 倉庫鏈接。
三、Python 識別手寫字符
3.1 加載權重矩陣
我們使用提前訓練好的神經網絡參數,再提醒一下訓練神經網絡就是訓練每層之間的連接權重,這些連接權重組和起來就是權重矩陣,相鄰的 2 層之間有一個權重矩陣,我們就是加載這些矩陣,然後用這些矩陣與輸入圖像的 400 個像素組成的向量一步步相乘,最終得出一個 1 X 10 的向量表示預測的數字是哪個。
加載權重的代碼如下:
# 加載已經訓練好的 3 層神經網絡參數
def load_weight(path):
data = sio.loadmat(path)
return data['Theta1'], data['Theta2']
我們來加載 2 個權重矩陣(因爲我們是 3 層神經網絡,所以只有 2 個權重矩陣哦):
# 使用已經訓練好的神經網絡參數
# 輸入層:400,隱藏層:25,輸出層:10
theta1, theta2 = load_weight('ex3weights.mat')
theta1.shape, theta2.shape
輸入的 2 個權重矩陣的維度分別是:
(25, 401), (10, 26)
這符合我們的網絡架構:
- theta1 是 25 行 401 列,行數爲隱藏層單元數,列數爲輸入層單元數 + 一個偏置單元
- theta2 是 10 行 26 列,行數爲輸出層單元數,列數爲隱藏層單元數 + 一個偏置單元
如下圖:
3.2 開始前饋預測
先加載要識別的手寫字符數據:
X, y = load_data('ex3data1.mat', transpose = False)
X.shape, y.shape
輸入的 X 是 5000 行 401 列,y 是 5000 行一列,這個數據集我上篇文章有詳細介紹過:從 0 開始機器學習 - 邏輯迴歸識別手寫字符!
先定義下每層的輸入輸出:
- :表示每層神經元的輸出,注意是經過 sigmoid 等激活函數運算後的輸出
- :表示每層神經元的輸入,
首先把輸入的特徵矩陣直接作爲第一層神經元的輸出 ,注意雖然這裏是把所有的樣本一次性輸入給神經網絡,但是因爲是矩陣運算,所以可以理解爲每次處理一個樣本,也就是特徵矩陣的一行:
a1 = X
計算第二層隱藏層的輸入 ,注意這裏對 theta1 取了轉置,是因爲矩陣要能夠相乘,必須第一個矩陣的列數等於第二個矩陣的行數:
z2 = a1 @ theta1.T
給第二層隱藏層加上偏置單元,也就是增加第一列全 1 向量:
z2 = np.insert(z2, 0, values = np.ones(z2.shape[0]), axis = 1)
計算第二層隱藏層神經元的輸出 $,使用 sigmoid
激活函數:
a2 = sigmoid(z2)
再繼續利用隱藏層的輸出作爲最後輸出層的輸入 ,這裏的轉置也是爲了做矩陣乘法:
z3 = a2 @ theta2.T
計算輸出層的輸出 :
a3 = sigmoid(z3)
這個 就是我們最後對原始手寫字符的識別結果,如下:
不過這樣不太直觀,我們再來取出每行最大值的索引,就是識別的數字:
y_pred = np.argmax(a3, axis = 1) + 1
y_pred
結果如下:
array([10, 10, 10, ..., 9, 9, 9])
第一個 10 表示原數據集第一個書寫字符識別爲 10,最後一個 9 表示原數據集最後一個手寫字符的識別結果,有了識別的結果,我們再來看看總體識別的準確度如何?
來打印下分類報告:
print(classification_report(y, y_pred))
可以看到識別的準確度還是挺高的,能達到 97% 以上!OK!今天登龍就跟大家分享這些,下期再見!
文中的完整可運行代碼鏈接:前饋神經網絡預測代碼
本文原創首發於微信公號「登龍」,分享機器學習、算法編程、Python、機器人技術等原創文章,掃碼關注回覆「1024」你懂的!