簡單的tensorflow:2_兩層網絡and more

源碼來自:
https://github.com/nlintz/TensorFlow-Tutorials
感謝作者無私提供

初學者可以先看懂代碼,然後一定要獨立實現一次。



已經看過上一篇文章簡單的tensorflow:1_logistic regression
的同學,會發現這次的代碼僅僅改動了一點點,就是model函數由一層變成了兩層。其他部分的講解見上一篇文章。

def model(X, w_1, w_2):
   h = tf.nn.sigmoid(tf.matmul(X, w_1))
   return tf.matmul(h, w_2)

來看看網絡的變化,這次我們需要w_1、w_2兩個參數,因爲有兩層網絡,每個w控制一層。輸入層輸入X,進入到第一層隱藏層,進行了一個X與w_1的矩陣乘法,然後套了一個sigmoid激活函數,得到隱層輸出h,然後h作爲下一層的輸入與w_2進行矩陣乘法,從而得到模型的輸出。

最後看這個新的網絡的訓練情況:
0 0.5943
1 0.8026
2 0.8536
3 0.8762
4 0.8893
5 0.8948
6 0.8984
7 0.9016
8 0.9052
9 0.9075
10 0.9089
11 0.9111
12 0.9129
13 0.914
14 0.9152
15 0.916
16 0.9172
17 0.9183
18 0.9194
19 0.9206
20 0.9217
21 0.9222
22 0.9228
23 0.9236
24 0.9248
25 0.925
26 0.9254
27 0.9263
28 0.9267
29 0.9275
30 0.9277
31 0.9287
32 0.929
33 0.9295
34 0.9305
35 0.9313
36 0.9319
37 0.9326
38 0.9341
39 0.9351
40 0.9357
41 0.9363
42 0.9372
43 0.9377
44 0.9387
45 0.939
46 0.9398
47 0.9401
48 0.941
49 0.9412
50 0.9423
51 0.9427
52 0.9436
53 0.9438
54 0.9443
55 0.9448
56 0.945
57 0.9455
58 0.9457
59 0.9464
60 0.9472
61 0.9477
62 0.9482
63 0.9484
64 0.9489
65 0.949
66 0.9495
67 0.9504
68 0.9507
69 0.951
70 0.9512
71 0.9514
72 0.9517
73 0.9526
74 0.9533
75 0.9538
76 0.9541
77 0.9545
78 0.9549
79 0.9553
80 0.9557
81 0.9559
82 0.9564
83 0.9568
84 0.9569
85 0.9572
86 0.9574
87 0.9578
88 0.9579
89 0.9581
90 0.9585
91 0.9588
92 0.9594
93 0.9595
94 0.9594
95 0.9599
96 0.9602
97 0.9603
98 0.9609
99 0.9613

與上一篇文章中單層網絡最後的92.35相比,僅僅多加了一層網絡,效果就大幅提升到了0.96以上,是不是很神奇,是不是對更復雜的網絡更加期待了呢?敬請關注我的簡單的tensorflow系列後續文章。

就在這篇裏更新吧,分割線


上一份代碼僅僅加深了一層網絡就有可喜的提升,這一次,再複雜一點點。

初學者還是先看,然後再獨立實現。


仍然只是模型發生了一點變化

def model(X, w_1, w_2, w_3, p_keep_input, p_keep_hidden):
   X = tf.nn.dropout(X, keep_prob=p_keep_input)
   h1 = tf.nn.relu(tf.matmul(X, w_1))
   h1 = tf.nn.dropout(h1, keep_prob=p_keep_hidden)
   h2 = tf.nn.relu(tf.matmul(h1, w_2))
   h2 = tf.nn.dropout(h2, keep_prob=p_keep_hidden)
   return tf.matmul(h2, w_3)

可以看出,這次又多了一個w參數,也就是又多加了一層網絡,並且多了一個用於防止過擬合而存在的dropout。dropout是根據你設置的參數隨機保留相應數量的神經元,使得每次訓練並不激活全部神經元,從而防止過擬合。

運行效果如下:
0.9079
0.9499
0.9622
0.9669
0.9704
0.9731
0.9728
0.9755
0.9768
0.9762
0.9759
0.9769
0.9773
0.9775
0.9771
0.9793
0.9792
0.9798
0.9798
0.9796
0.9795
0.98
0.9803
0.9798
0.9808
0.9803
0.9805
0.9806
0.9804
0.9802
0.9807
0.9818
0.9803
0.9815
0.9808
0.9811
0.9823
0.9815
0.9809
0.9804
0.982
0.982
0.9808
0.9819
0.9815
0.9816
0.9823
0.9823
0.9818
0.9819
0.9818
0.9809
0.9823
0.9817
0.9802
0.9821
0.983
0.983
0.9819
0.9822
0.9834
0.9811
0.9816
0.9813
0.9807
0.9828
0.9819
0.9821
0.9822
0.9834
0.9811
0.9826
0.9828
0.9831
0.9819
0.9818
0.9821
0.9806
0.9817
0.9828
0.9828
0.9826
0.9826
0.9825
0.9835
0.9826
0.9822
0.982
0.9821
0.9822
0.9834
0.9825
0.9831
0.9825
0.9834
0.9828
0.9824
0.9834
0.983
0.9833
比上一份代碼加深一層網絡,增加了dropout,準確率又上升了大約2%,已經達到了98%以上,這麼簡單的一個模型,就可以做到輕鬆識別mnist數據集的手寫數據。

發佈了50 篇原創文章 · 獲贊 46 · 訪問量 12萬+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章