應用Tensorflow2.0的Eager模式快速構建神經網絡

TensorFlow是開發深度學習算法的主流框架,近來隨着keras和pytorch等框架的崛起,它受到了不小挑戰,爲了應對競爭它本身也在進化,最近新出的2.0版本使得框架的應用更加簡易和容易上手,本節我們就如何使用它2.0版本提出的eager模式進行探討,在後面章節中我們將使用它來開發較爲複雜的生成型對抗性網絡。

最新流行的深度學習框架keras一大特點是接口的易用性和可理解性,它在Tensorflow的基礎上進行了深度封裝,它把很多技術細節隱藏起來,同時調整設計模式,使得基於keras的開發比Tensorflow要簡單得多。但keras對應的問題是,封裝太好雖然有利於易用性,但是不利於開發人員,特別是初學者對模型設計的深入理解,由於我們主題是學習神經網絡的設計原理,由於keras對模型設計模式的細節封裝過度,因此反而不利於學習者。爲了兼顧易用性和對設計細節的把握性,我選擇TF2.0帶來的Eager模式,這樣就能魚和熊掌兼得。

我們首先看看Eager模式和傳統模式有何區別。傳統模式一大特點是代碼首先要創建一個會話對象,深度學習網絡模型實際上是由多種運算節點構成的一張運算圖,模型運行時需要依賴會話對象對運算圖的驅動和管理,我們先看看傳統模式的基本開發流程:

import tensorflow as tf
a = tf.constant(3.0)
b = tf.placeholder(dtype = tf.float32)
c = tf.add(a,b)
sess = tf.Session() #創建會話對象
init = tf.global_variables_initializer()
sess.run(init) #初始化會話對象
feed = {
    b: 2.0
} #對變量b賦值
c_res = sess.run(c, feed) #通過會話驅動計算圖獲取計算結果
print(c_res)

從上面代碼看你會感覺有一種彆扭,placeholder用來開闢一塊內存,然後通過feed再把數值賦值到被開闢的內存中,然後再使用run驅動整個計算流程的運轉,這種設計模式與傳統編程模式的區別在於饒了一個彎,對很多TF的初學者而言,一開始要花不少精力去適應這種模式。

我們再看看eager模式下上面代碼的設計過程,首先要注意一點是,要開啓eager模式,需要在最開始處先執行如下代碼:

import tensorflow as tf
import tensorflow.contrib.eager as tfe
tf.enable_eager_execution()

代碼執行後TF就進入eager模式,接下來我們看看如何實現前面的運算步驟:

def  add(num1, num2):
    a = tf.convert_to_tensor(num1) #將數值轉換爲TF張量,這有利於加快運算速度
    b = tf.convert_to_tensor(num2)
    c = a + b
    return c.numpy() #將張量轉換爲數值
add_res = add(3.0, 4.0)
print(add_res)

代碼運行後輸出結果7.0,可以看到eager模式的特點是省掉了傳統模式繞彎的特點,它可以像傳統編程模式那樣從上到下的方式執行所有運算步驟,不需要特別去創建一個會話對象,然後再通過會話對象驅動所有運算步驟的執行,這種設計模式就更加簡單易懂

我們看看如何使用eager模式開發一個簡單的神經網絡。類似"Hello World!",在神經網絡編程中常用與入門的練手項目叫鳶尾花識別,它的花瓣特徵明顯,不同品種對應花瓣的寬度和長度不同,因此可以通過通過神經網絡讀取花瓣信息後識別出其對應的品種,首先我們先加載相應訓練數據:

from sklearn import datasets, preprocessing, model_selection
data = datasets.load_iris() #加載數據到內存
x = preprocessing.MinMaxScaler(feature_range = (-1, 1)).fit_transform(data['data']) #將數據數值預處理到(-1,1)之間方便網絡識別
#把不同分類的品種用向量表示,例如有三個不同品種,那麼分別用(1,0,0),(0,1,0),(0,0,1)表示
y = preprocessing.OneHotEncoder(sparse = False).fit_transform(data['target'].reshape(-1, 1))
x_train, x_test, y_train, y_test = model_selection.train_test_split(x, y, test_size = 0.25, stratify = y) #將數據分成訓練集合測試集
print(len(x_train))

代碼運行後可以看到擁有訓練的數據有112條。接下來我們創建一個簡單的三層網絡:

class IrisClassifyModel(object):
    def  __init__(self, hidden_unit, output_unit):
        #這裏只構建兩層網絡,第一層是輸入數據
        self.hidden_layer = tf.keras.layers.Dense(units = hidden_unit, activation = tf.nn.tanh, use_bias = True, name="hidden_layer")
        self.output_layer = tf.keras.layers.Dense(units = output_unit, activation = None, use_bias = True, name="output_layer")
    def  __call__(self, inputs):
        return self.output_layer(self.hidden_layer(inputs))

我們用如下代碼檢測一下網絡構建的正確性:

#構造輸入數據檢驗網絡是否正常運行
model = IrisClassifyModel(10, 3)
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
for x, y in tfe.Iterator(train_dataset.batch(32)):
    output = model(x)
    print(output.numpy())
    break

代碼如果正確運行並輸出相應結果,那表明網絡設計沒有太大問題。接着我們用下面代碼設計損失函數和統計網絡預測的準確性:

def  make_loss(model, inputs, labels):
    return  tf.reduce_sum(tf.nn.softmax_cross_entropy_with_logits_v2(logits = model(inputs), labels = labels))
opt = tf.train.AdamOptimizer(learning_rate = 0.01)
def train(model, x, y):
    opt.minimize(lambda:make_loss(model, x, y))
accuracy = tfe.metrics.Accuracy()
def  check_accuracy(model, x_batch, y_batch): #統計網絡判斷結果的準確性
    accuracy(tf.argmax(model(tf.constant(x_batch)), axis = 1), tf.argmax(tf.constant(y_batch), axis = 1))
    return accuracy

最後我們啓動網絡訓練流程,然後將網絡訓練的結果繪製出來:

import numpy as np
model = IrisClassifyModel(10, 3)
epochs = 50
acc_history = np.zeros(epochs)
for epoch in range(epochs):
    for (x_batch, y_batch) in tfe.Iterator(train_dataset.shuffle(1000).batch(32)):
        train(model, x_batch, y_batch)
        acc = check_accuracy(model, x_batch, y_batch)
        acc_history[epoch] = acc.result().numpy()

import matplotlib.pyplot as plt
plt.figure()
plt.plot(acc_history)
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.show()

上面代碼運行後結果如下:
image.png

可以看到網絡經過訓練後準確率達到95%以上。本節的目的是爲了介紹TF2.0的eager模式,爲後面開發更復雜的網絡做技術準備。

更詳細的講解和代碼調試演示過程,請點擊鏈接

更多技術信息,包括操作系統,編譯器,面試算法,機器學習,人工智能,請關照我的公衆號:
這裏寫圖片描述

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章