認識tensorflow的幾大組成元件

在使用tensorflow實現機器學習算法的程序中,需要有哪些基本組成部分?

做機器學習的入門還是挺難的,除了需要晦澀的數學公式之外,還在代碼實現上有許多挑戰,那麼,麼一個tensorflow實現一個機器學習程序,其到底需要哪些部分?
這裏就以線性擬合爲例子,本教程參考了莫煩python的教程,看了之後作爲自己總結,把其中一些關鍵的信息拎出來。

首先,我們有一些輸入數據,需要用一根曲線去擬合它的目標,也就是訓練w與b使得y=wx+b。
那麼,我們主要的部分爲:有變量來充當這些輸入數據。然後會有優化的目標函數,從而定義損耗函數以及優化策略。
tensorflow的設計者們將這個過程比擬爲一個圖性。
輸入數據位於圖的輸入端,有時我們可能會反覆的迭代,也就是輸入數據會不斷變化。那麼,就給輸入數據一個位置,然後將自己要輸入的數據對應上即可。有點像先安裝一個插座,然後接自己的線。這個組件叫做placeholde.

tf_x = tf.placeholder(tf.float32, x.shape)     # input x
tf_y = tf.placeholder(tf.float32, y.shape)     # input y

激勵
然後我們需要定義每一層的對輸入的激勵。也就是將這個輸入數據映射爲其他值

y_relu = tf.nn.relu(x)  # x是一個placeholder。這裏相當於每一層中都是placeholder即插座在打交道,而輸入和輸出數據只在起止層交互。最直觀的就是對每一層都的參數w與b都使用placeholder。

有時,我們的輸入層數據需要進行多個數據的拼接,那就是tf.concat函數。
tf 中數據的拼接

# tf.concat([tensor1,tensor2,...,tensor n], axis) 這個是形式,拼接的數據用[]包起來,axis從0開始計數,表示維度
#進行數據的拼接,拼接的維度是通過axis來控制的
tensor_one = [[1,2,3],[2,3,4]] #原始數據爲2*3
tensor_two = [[11,12,13],[11,13,14]]
tf.concat([tensor_one,tensor_two],0) # 在第一維度進行數據拼接,4*3
tf.concat([tensor_one,tensor_two],1) # 在第二維度進行數據拼接,2*6

損耗函數
然後我們需要定義cost function,併爲其定義優化函數

loss = tf.losses.mean_squared_error(tf_y, output)   # compute cost
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.5)
train_op = optimizer.minimize(loss)

順利成章的,需要對這個損耗函數進行求解,也就是梯度下降等方法,稱爲solver

# Solver
D_solver = tf.train.AdamOptimizar().minimize(loss,var_list=theta_D) #優化的參數爲theta_D

然後就可以定義Session來初始化變量,以及開始優化訓練了

# 定義Sess
sess = tf.Session()
sess.run(tf.global_variables_initializer())

with tf.Session() as sess:
    out,loss = sess.run([solver,loss],feed_dict={x:x_input,Y:Y_input}) ## feed_dict是通過{}包裹的。然後這裏的等式左邊的兩個變量是與等式右邊的[solver,loss]相匹配,相當於變量的插座定義好了,然後從中間拿到我們的計算結果

    ## 如果我們只關心loss的變化,可以用下劃線替代
    # _,loss = sess.run([solver,loss],feed_dict={x:x_input,Y:Y_input}) ##

    # 然後我們可以選取一些迭代次數進行數據的輸出或者畫圖
    if step % 100 == 0:
        # do somethign 
        pass

其他部分

除此之外,我們還要掌握一些基本的數據處理方法。
熟悉list與tuple,更強大的應該是np.array吧。
例如,數據的讀取,python讀取csv文件,獲取其shape等。

python 添加路徑

# 添加一個文件路徑,然後將結果輸出到該路徑下面
import os

if not os.path.exists('Output/'):
    os.makedirs('Output/')

數據順序的打亂

data = [[1,2,3,4],
        [4,5,6,7]]
data = np.array(data)
print(data.shape)
permutation = np.random.permutation(data.shape[1])
print(permutation)
data_p = data[:,permutation]
print(data_p)

莫煩的github上有許多de’mo例子,跟着看看實踐就好。
參考文獻2(tensorflow白皮書)中的圖1與圖2就很清晰了。
參考文獻:
[1]: https://morvanzhou.github.io/tutorials/
[2]:https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/45166.pdf

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章