【135】TensorFlow利用神經網絡學習XOR(異或)並部署成java代碼

本文python代碼使用 python 3。
本文參考了《深度學習》第107頁,6.1 實例:學習XOR

XOR 函數也稱爲異或。輸入兩個布爾型的變量 x1 和 x2 。當 x1 和 x2 不相同的時候,返回True。當 x1 和 x2 相同的時候返回 False。爲了方便計算機處理,我用 1 表示True,0 表示False。

我創建一個CSV文件XOR_train.csv,裏面內容就是異或的輸入x1 x2 和 結果 result。

XOR_train.csv

"x1","x2","result"
0,0,0
0,1,1
1,0,1
1,1,0
1,1,0
0,0,0
1,0,1
0,1,1
1,0,1
0,1,1
0,0,0
1,1,0
0,1,1

python代碼:

import tensorflow as tf
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from IPython import display

class Blog135Train:
    # 從CSV文件中讀取數據,返回DataFrame類型的數據集合。
    def read_csv(self):
        v_dataframe = pd.read_csv("XOR_train.csv", sep=",")
        # 打亂數據集合的順序。有時候數據文件有可能是根據某種順序排列的,會影響到我們對數據的處理。
        v_dataframe = v_dataframe.reindex(np.random.permutation(v_dataframe.index))
        return v_dataframe
    
    # 預處理特徵值
    def preprocess_features(self, xor_dataframe): 
        return xor_dataframe[["x1", "x2"]].copy()
    
    # 預處理標籤
    def preprocess_lebels(self, xor_dataframe): 
        return xor_dataframe[["result"]].copy()
    
    # 獲得特徵值的矩陣
    def get_matrix(self, pa_dataframe):
        v_result = []
        # dataframe中每列的名稱。
        zc_var_col_name_arr = [e for e in pa_dataframe]
        # 遍歷dataframe中的每行。
        for row_index in pa_dataframe.index:
            zc_var_tf_row = []
            for i in range(len(zc_var_col_name_arr)):
                zc_var_tf_row.append(float(pa_dataframe.loc[row_index].values[i]))
            v_result.append(zc_var_tf_row)
        return v_result
        
    
    def train(self, pa_feature_matrix, pa_label_matrix, pa_step_num, pa_learn_rate):
        v_loss_arr = []
        with tf.Session() as sess:
            v_input = tf.constant(pa_feature_matrix)
            v_label = tf.constant(pa_label_matrix)
            v_weight_0 = tf.Variable(tf.random_normal([2, 2], 0, 0.1))
            v_weight_1 = tf.Variable(tf.random_normal([1, 2], 0, 0.1))
            v_weight_2 = tf.Variable(tf.random_normal([2, 1], 0, 0.1))
            
            # 初始化上面所有的 TF 常量和變量。
            tf.global_variables_initializer().run()
            v_middle_0 = tf.matmul(v_input, v_weight_0)
            v_middle_1 = tf.add(v_middle_0, v_weight_1)
            v_middle_2 = tf.nn.relu(v_middle_1)
            v_yhat = tf.matmul(v_middle_2, v_weight_2)
            # tf.subtract計算兩個張量相減,當然兩個張量必須形狀一樣。 即 v_yhat - v_label。
            v_yerror = tf.subtract(v_yhat, v_label)
            # 計算L2損失,也就是方差。
            v_loss = tf.nn.l2_loss(v_yerror)
    
            # 梯度下降算法。
            v_optimizer = tf.train.GradientDescentOptimizer(pa_learn_rate)
            # 注意:爲了安全起見,我們還會通過 clip_gradients_by_norm 將梯度裁剪應用到我們的優化器。
            # 梯度裁剪可確保梯度大小在訓練期間不會變得過大,梯度過大會導致梯度下降法失敗。
            v_optimizer = tf.contrib.estimator.clip_gradients_by_norm(v_optimizer, 5.0)
            v_optimizer = v_optimizer.minimize(v_loss)
            
            for _ in range(pa_step_num):
                # 重複執行梯度下降算法,更新權重數值,找到最合適的權重數值。
                sess.run(v_optimizer)
                # 每次循環都記錄下損失loss的值,並放到數組loss_arr中。
                v_loss_arr.append(v_loss.eval())
            v_weight_0_result = v_weight_0.eval()
            v_weight_1_result = v_weight_1.eval()
            v_weight_2_result = v_weight_2.eval()
            v_yhat_result = v_yhat.eval()
        return (v_loss_arr, v_weight_0_result, v_weight_1_result, v_weight_2_result, v_yhat_result)
    
    
    # 畫損失的變化圖。
    # pa_ax  Axes
    # pa_arr_train_rmse 訓練次數。
    # pa_arr_validate_rmse 損失變化的記錄
    def fn_paint_loss(self, pa_ax, pa_arr_train_rmse, pa_arr_validate_rmse):
        pa_ax.plot(range(0, len(pa_arr_train_rmse)), pa_arr_train_rmse, label="training", color="blue")
        pa_ax.plot(range(0, len(pa_arr_validate_rmse)), pa_arr_validate_rmse, label="validate", color="orange")
    
    # 取得數組的最後一項
    def get_last_in_arr(self, arr):
        v_len = len(arr)
        return arr[v_len - 1]
        
    
    def main(self):
        xor_dataframe = self.read_csv()
        display.display(xor_dataframe.describe())
        print(xor_dataframe)
        # 得到特徵值的矩陣
        v_process_features = self.preprocess_features(xor_dataframe)
        v_feature_matrix = self.get_matrix(v_process_features)
        # 得到標籤值的矩陣
        v_process_label = self.preprocess_lebels(xor_dataframe)
        v_label_matrix = self.get_matrix(v_process_label)

        # 因爲訓練中存在 v_weight_0 = tf.Variable(tf.random_normal([2, 2], 0, 0.1)) 這種隨即因素,所以可能導致模型停留在
        # 某個損失曲線低窪處(這個低窪處不是全局最低點)
        v_loss_arr = [999]
        for _ in range(9):
            (v_tmp_loss_arr, v_tmp_weight_0, v_tmp_weight_1, 
             v_tmp_weight_2, v_tmp_yhat) = self.train(v_feature_matrix, v_label_matrix, 260, 0.02)
            if self.get_last_in_arr(v_tmp_loss_arr) < self.get_last_in_arr(v_loss_arr):
                v_loss_arr = v_tmp_loss_arr
                v_weight_0 = v_tmp_weight_0
                v_weight_1 = v_tmp_weight_1
                v_weight_2 = v_tmp_weight_2
                v_yhat = v_tmp_yhat
        
        
        print("Training finished:")
        print(v_yhat)
        print(v_weight_0)
        print(v_weight_1)
        print(v_weight_2)
            
        
        print("loss: ")
        v_len = len(v_loss_arr)
        v_index = 0
        for _ in range(9):
            v_index = v_index + 1
            print(v_loss_arr[v_len - v_index])
        
        # 畫出損失變化曲線
        fig = plt.figure()
        fig.set_size_inches(5,5)
        self.fn_paint_loss(fig.add_subplot(1,1,1), v_loss_arr, v_loss_arr)
        plt.show()
    

blog135Train = Blog135Train()
blog135Train.main()

運行結果:

Training finished:
yhat:
[[0.00164479]
[0. ]
[0. ]
[0.00164479]
[0.9994182 ]
[0.00164479]
[0. ]
[1. ]
[1. ]
[0.9994182 ]
[0.9994182 ]
[0.9994182 ]
[1. ]]
weight_0:
[[ 0.85549915 -0.85027605]
[-0.854509 0.84682876]]
weight_1:
[[-0.0071719 0.00139597]]
weight_2:
[[1.1787903]
[1.178247 ]]

1.png

我們可以通過 train 函數得知計算過程是這樣的:

2.png

既然我們已經知道了權重和計算過程,那麼我們就能夠輕而易舉地把訓練好的模型遷移到其他語言上。我以java爲例作個DEMO。

java項目名稱是 testXOR ,用Maven管理。

下載 jama 的jar包,地址是: https://math.nist.gov/javanumerics/jama/
本地安裝 jama

mvn install:install-file -Dfile=Jama-1.0.3.jar -DgroupId=jama -DartifactId=jama -Dversion=1.0.3 -Dpackaging=jar

pom.xml

<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
	<modelVersion>4.0.0</modelVersion>
	<groupId>zhangchao</groupId>
	<artifactId>testXOR</artifactId>
	<version>0.0.1-SNAPSHOT</version>
	
	<properties>
		<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
		<project.reporting.outputEncoding>UTF-8</project.reporting.outputEncoding>
		<java.version>1.8</java.version>
		<maven.compiler.source>1.8</maven.compiler.source>
		<maven.compiler.target>1.8</maven.compiler.target>
	</properties>
	<dependencies>
		<dependency>
			<groupId>jama</groupId>
			<artifactId>jama</artifactId>
			<version>1.0.3</version>
		</dependency>
	</dependencies>
	<build>
        <plugins>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-jar-plugin</artifactId>
                <configuration>
                    <archive>
                        <manifest>
                            <mainClass>testXOR.Main</mainClass> <!-- 你的主類名 -->
                        </manifest>
                    </archive>
                </configuration>
            </plugin>
        </plugins>
    </build>
</project>

Main.java

package testXOR;

import Jama.Matrix;

public class Main {
	
	final double[][] weight_0 = {
			{ 0.85549915, -0.85027605},
			{-0.854509  ,  0.84682876}
	};
	
	final double[][] weight_1 = {
			{-0.0071719, 0.00139597}
	};
	
	final double[][] weight_2 = {
			{1.1787903},
			{1.178247}
	};
	
	/**
	 * relu 函數
	 * @param paMatrix
	 * @return
	 */
	public Matrix relu(Matrix paMatrix){
		Matrix matrix = paMatrix.copy();
		int rowDimension = matrix.getRowDimension();
		int columnDimension = matrix.getColumnDimension();
		for (int i = 0; i < rowDimension; i++) {
			for (int j = 0; j < columnDimension; j++) {
				double v = matrix.get(i, j);
				matrix.set(i, j, Math.max(0, v));
			}
		}
		return matrix;
	}
	
	/**
	 * 測試輸入值和返回結果。正常情況按照異或的邏輯返回。
	 * @param x1  1.0 或 0.0
	 * @param x2  1.0 或 0.0
	 * @return  返回 1.0 或 0.0
	 */
	public double test(double x1, double x2){
		double r = 0.0;
		double[][] inputArr = {{x1, x2}};
		Matrix inputMatrix = new Matrix(inputArr);
		Matrix weight_0_matrix = new Matrix(weight_0);
		Matrix weight_1_matrix = new Matrix(weight_1);
		Matrix weight_2_matrix = new Matrix(weight_2);
		
		Matrix middle_0 = inputMatrix.times(weight_0_matrix);
		Matrix middle_1 = middle_0.plus(weight_1_matrix);
		Matrix middle_2 = relu(middle_1);
		Matrix yhat = middle_2.times(weight_2_matrix);
		
		if( yhat.get(0, 0) > 0.5 ){
			r = 1.0;
		}
		return r;
	}
	
	public static void main(String[] args){
		double yhat = -1;
		yhat = new Main().test(0, 0);
		System.out.println(yhat);
		
		yhat = new Main().test(1, 0);
		System.out.println(yhat);
		
		yhat = new Main().test(0, 1);
		System.out.println(yhat);
		
		yhat = new Main().test(1, 1);
		System.out.println(yhat);
	}
}

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章