訓練一個簡單的Tensorflow神經網絡模型

以下代碼運行於Google Colaboratory:

import tensorflow as tf
from numpy.random import RandomState

batch_size = 8

# 定義神經網絡參數
w1 = tf.Variable(tf.random_normal([2, 3], stddev=1, seed=1))
w2 = tf.Variable(tf.random_normal([3, 1], stddev=1, seed=1))

# 定義輸入
x = tf.placeholder(tf.float32, shape=(None, 2), name='x-input')
y_ = tf.placeholder(tf.float32, shape=(None, 1), name='y-input')

# 前向傳播
a = tf.matmul(x, w1)
y = tf.matmul(a, w2)

# 定義損失函數和反向傳播算法
cross_entropy = -tf.reduce_mean(
    y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0)))
train_step = tf.train.AdamOptimizer(0.001).minimize(cross_entropy)

# 通過隨機數生成一個模擬數據集
rdm = RandomState(1)
dataset_size = 128
X = rdm.rand(dataset_size, 2)

# 定義規則來給出樣本標籤
Y = [[int(x1+x2 < 1)] for (x1, x2) in X]

# 創建一個會話
with tf.Session() as sess:
  init_op = tf.global_variables_initializer()
  sess.run(init_op)
  print('w1:', sess.run(w1))
  print('w2:', sess.run(w2))
  
  # 設定訓練的輪數
  STEPS = 5000
  for i in range(STEPS):
    
    start = (i * batch_size) % dataset_size
    end = min(start+batch_size, dataset_size)
    
    sess.run(train_step,
            feed_dict={x: X[start:end], y_: Y[start:end]})
    
    # 每隔一段時間計算在所有數據上的交叉熵
    if i % 1000 == 0:
      total_cross_entropy = sess.run(cross_entropy,
                                     feed_dict={x: X, y_: Y})
      print("After %d trainning step(s), cross entropy on all data id %g"
            % (i, total_cross_entropy))
      
  # 訓練結束
  print('w1:', sess.run(w1))
  print('w2:', sess.run(w2))

輸出結果如下:

w1: [[-0.8113182   1.4845988   0.06532937]
 [-2.4427042   0.0992484   0.5912243 ]]
w2: [[-0.8113182 ]
 [ 1.4845988 ]
 [ 0.06532937]]
After 0 trainning step(s), cross entropy on all data id 0.0674925
After 1000 trainning step(s), cross entropy on all data id 0.0163385
After 2000 trainning step(s), cross entropy on all data id 0.00907547
After 3000 trainning step(s), cross entropy on all data id 0.00714436
After 4000 trainning step(s), cross entropy on all data id 0.00578471
w1: [[-1.9618274  2.582354   1.6820378]
 [-3.4681718  1.0698233  2.11789  ]]
w2: [[-1.8247149]
 [ 2.6854665]
 [ 1.4181951]]

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章