本文主要記錄如何使用最簡單的人工神經網絡求解二分類問題。
先決條件:
- 問題是線性可分的。
- 訓練數據集是線性可分的。
基礎知識
需要我們分類的樣本是什麼樣子呢?來看看下面這幅最簡單的二分類圖
事物特徵
上圖中,x1 與 x2 代表了事物的兩個屬性,例如在老虎和獅子的分類問題中,x1 和 x2 可以分別代表 “前掌寬度” 和 “體型大小” 兩個屬性(這裏只是舉例,實際情況可能和示例圖有差別),這些屬性可以用來區分老虎和獅子,我們稱他們爲特徵值。如何提取事物特徵爲特徵工程領域問題,這裏不做過多闡述。
線性可分
線性可分定義非常簡單:對於樣本集,存在一條直線、一個平面或一個超平面可以將兩類事物分割開來,則稱線性可分。如上圖所示,class1 與 class2 可以使用一條直線分割開來。
分割面公式
對於一條直線,我們需要一個點來分割兩類;對於二維空間,我們需要一條直線來分割;對於三維空間,我們需要一個二維平面來分割;對於四維空間,我們需要一個三維的超平面來分割;依此類推,更高維度的空間,都需要找一個比它低一維的超平面才能進行分割,線性可分的分類問題即是如此,我們需要找到一個超平面來將不同的分類分隔開,那麼這個平面的方程爲:Ax+By+Cz+D=0(平面方程的一般式),其中A、B、C、D均爲常量且不同時爲0,類比可得,更高維的超平面將是w1x+w2x+…+wnx+b=0。
對於上圖問題,ax1+bx2+c=0 則是將 class1 與 class2 分割開來的一個超平面。
簡單人工神經網絡(ANN)模型
公式
這個公式其實就是分割面公式。
輸入 X
上圖則爲一個最簡的ANN模型,其中 x1~x3 爲輸入 ,這些輸入分別代表了問題的某個特徵,例如,在獅子和老虎的分類問題中,x1…xn 可以分別代表體毛程度、體重、紋路、奔跑速度、咬合力…,這些變量我們統一定義爲 X 並且都爲數值,因爲只有有了相應的數值,才能量化特徵。一般來說,這些特性都是足夠用於區別不同分類,若某個特徵在這個分類問題中散亂分佈,例如站立時的形狀輪廓
對於上兩幅輪廓圖,我們可能並不能找到合適的方法來比較準確的區分獅子和老虎,這種特徵則爲噪聲。
權重 W 與偏置 b
其中的 w1~wn 爲權重變量,統一定義爲 W,可以簡單理解爲其對應特徵的影響程度的描述,就如同我們判斷老虎和獅子時通常是通過對比體毛和紋路來得到答案,我們更關注與這兩個特徵,因此我們大概可以判定,W 在這兩個特徵上的絕對值將會偏大。
其中的 b 爲偏置,它是爲了更好的擬合數據,有興趣的可自行百度。
輸出 y
y 爲神經網絡最終得到的結果。在二分類問題中,我們通常將 y 定義爲:當 y>=0時,y=1,否則 y=-1。理解起來並不困難,因爲在二分類問題中,目標只可能爲兩個類別中的某一類,爲了方便觀察,以 0 爲界限,最終得到正數則爲類別 1,得到負數則爲類別 2。
訓練樣本
我們學習知識通常需要一個目標以及一些已知的知識儲備,機器學習一樣如此,需要輸入訓練樣本,機器才能知道需要學習的東西。我們將訓練樣本定義爲(X,Y),其中 X 爲多個特徵值(x1~xn),Y 則爲這個樣本的類別(或標籤),在二分類問題中,Y 一般爲 1 或 -1 來表示不同類別。
需要學習的參數
模型中,X 需要我們手動輸入,因此 X 是爲已知的,y 爲最終計算得到的結果,因此 y 也並不是需要學習的,權重 W 與偏置 b 沒有從任何地方可以得到,因此它們就是需要學習的參數。
學習方法
現在我們知道了模型的計算過程、結果定義以及需要學習的參數,那麼通過什麼手段來學習 W 和 b 呢?這裏選擇使用 感知器算法 對模型中的 W、b 進行更新。它是一種最簡單的學習算法,是一種二元線性分類器,具體公式爲
學習步驟:
- 在訓練樣本集中選擇一個樣本(X,Y),將這個樣本輸入神經網絡模型
- 當樣本類別 Y=1,且計算結果 y<=0 時,更新權重 W 與 b;當樣本類別 Y=-1,且計算結果 y>=0 時,更新權重 W 與 b
- 重複步驟 1,2,直到對於訓練樣本集中所有的樣本都不滿足上述公式,則爲模型訓練完成
模型測試
準備一組測試樣本集,測試樣本與訓練樣本格式一致。將測試樣本的 X 輸入到訓練好的模型中,模型計算得到 y,如 y 與 樣本類別 Y 符號一致(同爲正數或同爲負數),則分類成功,反之則爲失敗,成功率越高則說明模型訓練的越好。當然對於相同的訓練樣本集,反覆的訓練最終得到的模型都是一樣的。
Java實現
模型與學習算法都確定好了之後便可以進行編碼實現了
- 首先定義樣本數據類型 Data.java
public class Data {
// 特徵值
private double[] x;
// 標籤
private int y;
public Data(double[] x, int y) {
this.x = x;
if (y >= 0) {
this.y = 1;
} else {
this.y = -1;
}
}
public double[] getX() {
return x;
}
public void setX(double[] x) {
this.x = x;
}
public int getY() {
return y;
}
public void setY(int y) {
this.y = y;
}
}
- 簡單ANN模型,其中 a=0.5 是爲了防止步長過大導致權重過大超出 double 範圍
public class SimpleNeuralNetworks {
// 學習率
private double a = 0.5;
// 特徵維度
private int dimension_w;
// 權重
private double[] w;
// 偏置
private double b;
// 初始化特徵維度值
public void init(int i) {
dimension_w = i;
w = new double[i];
}
// 測試
public double test(double[] x) {
return train(x);
}
// 輸入樣本集進行訓練
public void train(List<Data> datas) {
int i = 0;
while (i < datas.size()) {
Data data = datas.get(i);
double res = train(data.getX());
if (res * data.getY() <= 0) {
updateW(data.getY(), data.getX());
i = 0;
} else {
i++;
}
}
}
// 更新權重,感知器算法
private void updateW(double label, double[] x) {
double m = label * a;
for (int i = 0; i < dimension_w; i++) {
w[i] += m * x[i];
}
b += m;
System.out.println("權重更新後---");
for (int i = 0; i < dimension_w; i++) {
System.out.print(w[i] + "\t");
}
System.out.print(b);
System.out.println();
}
// 計算樣本特徵
private double train(double[] x) {
double res = 0;
for (int i = 0; i < dimension_w; i++) {
res += w[i] * x[i];
}
return res + b;
}
}
- 有了數據類型和ANN模型了,就差數據集了。本文訓練使用的數據集下載自 UCI 的Iris Data Set,其數據集分三類,每一類均有50條樣本數據,這裏只取其中兩類做爲訓練與測試,每類選擇40條樣本數據作爲訓練樣本,剩餘10條作爲測試樣本。下載鏈接:https://download.csdn.net/download/lcsy000/12058888 。需要注意的是,此樣本中的 y 均爲英文,因此需要我們將其轉爲 1 和 -1,選擇一類爲 1,另一類則爲 -1。我是這樣做的處理
- 現在樣本集也有了,我們需要讀取它們到內存中,因此定義數據讀取器 DataReader.java
public final class DataReader {
private DataReader() {}
// 將文件讀取爲數據集
public static List<Data> read(String path, String splitStr) throws IOException {
List<Data> dataSet = new ArrayList<>();
BufferedReader br = new BufferedReader(new FileReader(path));
String line;
while ((line = br.readLine()) != null) {
if (line.trim().isEmpty()) {
continue;
}
String[] split = line.split(splitStr);
double[] x = new double[split.length - 1];
for (int i = 0; i < x.length; i++) {
x[i] = Double.valueOf(split[i]);
}
dataSet.add(new Data(x, Integer.parseInt(split[split.length - 1])));
}
return dataSet;
}
}
- 最後編寫主程序,並進行訓練
public class Main {
public static void main(String[] args) throws IOException {
List<Data> dataSets = DataReader.read("F:/neural-networks/train.data", ",");
SimpleNeuralNetworks networks = new SimpleNeuralNetworks();
networks.init(dataSets.get(0).getX().length);
networks.train(dataSets);
List<Data> testSets = DataReader.read("F:/neural-networks/test.data", ",");
testSets.forEach(test -> {
double res = networks.test(test.getX());
System.out.println(res);
if (res * test.getY() > 0) {
System.out.println("yes");
} else {
System.out.println("no");
}
});
}
}
- 運行結果
在幾次的權重更新之後便學習完畢,並且使用剩餘的 20 條數據進行測試得到 100% 的成功率。
到此我們就基本學會了如何利用最簡單的神經網絡模型與感知器算法去解決線性可分的二分類問題,但二分類是遠遠不夠的,因爲現實生活中更多的問題都爲多分類問題,那麼不妨思考一下,如何利用已知的模型與算法處理線性可分的多分類問題?