认识tensorflow的几大组成元件

在使用tensorflow实现机器学习算法的程序中,需要有哪些基本组成部分?

做机器学习的入门还是挺难的,除了需要晦涩的数学公式之外,还在代码实现上有许多挑战,那么,么一个tensorflow实现一个机器学习程序,其到底需要哪些部分?
这里就以线性拟合为例子,本教程参考了莫烦python的教程,看了之后作为自己总结,把其中一些关键的信息拎出来。

首先,我们有一些输入数据,需要用一根曲线去拟合它的目标,也就是训练w与b使得y=wx+b。
那么,我们主要的部分为:有变量来充当这些输入数据。然后会有优化的目标函数,从而定义损耗函数以及优化策略。
tensorflow的设计者们将这个过程比拟为一个图性。
输入数据位于图的输入端,有时我们可能会反复的迭代,也就是输入数据会不断变化。那么,就给输入数据一个位置,然后将自己要输入的数据对应上即可。有点像先安装一个插座,然后接自己的线。这个组件叫做placeholde.

tf_x = tf.placeholder(tf.float32, x.shape)     # input x
tf_y = tf.placeholder(tf.float32, y.shape)     # input y

激励
然后我们需要定义每一层的对输入的激励。也就是将这个输入数据映射为其他值

y_relu = tf.nn.relu(x)  # x是一个placeholder。这里相当于每一层中都是placeholder即插座在打交道,而输入和输出数据只在起止层交互。最直观的就是对每一层都的参数w与b都使用placeholder。

有时,我们的输入层数据需要进行多个数据的拼接,那就是tf.concat函数。
tf 中数据的拼接

# tf.concat([tensor1,tensor2,...,tensor n], axis) 这个是形式,拼接的数据用[]包起来,axis从0开始计数,表示维度
#进行数据的拼接,拼接的维度是通过axis来控制的
tensor_one = [[1,2,3],[2,3,4]] #原始数据为2*3
tensor_two = [[11,12,13],[11,13,14]]
tf.concat([tensor_one,tensor_two],0) # 在第一维度进行数据拼接,4*3
tf.concat([tensor_one,tensor_two],1) # 在第二维度进行数据拼接,2*6

损耗函数
然后我们需要定义cost function,并为其定义优化函数

loss = tf.losses.mean_squared_error(tf_y, output)   # compute cost
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.5)
train_op = optimizer.minimize(loss)

顺利成章的,需要对这个损耗函数进行求解,也就是梯度下降等方法,称为solver

# Solver
D_solver = tf.train.AdamOptimizar().minimize(loss,var_list=theta_D) #优化的参数为theta_D

然后就可以定义Session来初始化变量,以及开始优化训练了

# 定义Sess
sess = tf.Session()
sess.run(tf.global_variables_initializer())

with tf.Session() as sess:
    out,loss = sess.run([solver,loss],feed_dict={x:x_input,Y:Y_input}) ## feed_dict是通过{}包裹的。然后这里的等式左边的两个变量是与等式右边的[solver,loss]相匹配,相当于变量的插座定义好了,然后从中间拿到我们的计算结果

    ## 如果我们只关心loss的变化,可以用下划线替代
    # _,loss = sess.run([solver,loss],feed_dict={x:x_input,Y:Y_input}) ##

    # 然后我们可以选取一些迭代次数进行数据的输出或者画图
    if step % 100 == 0:
        # do somethign 
        pass

其他部分

除此之外,我们还要掌握一些基本的数据处理方法。
熟悉list与tuple,更强大的应该是np.array吧。
例如,数据的读取,python读取csv文件,获取其shape等。

python 添加路径

# 添加一个文件路径,然后将结果输出到该路径下面
import os

if not os.path.exists('Output/'):
    os.makedirs('Output/')

数据顺序的打乱

data = [[1,2,3,4],
        [4,5,6,7]]
data = np.array(data)
print(data.shape)
permutation = np.random.permutation(data.shape[1])
print(permutation)
data_p = data[:,permutation]
print(data_p)

莫烦的github上有许多de’mo例子,跟着看看实践就好。
参考文献2(tensorflow白皮书)中的图1与图2就很清晰了。
参考文献:
[1]: https://morvanzhou.github.io/tutorials/
[2]:https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/45166.pdf

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章