[機器學習]關於分類問題的梯度下降

在分類問題中 假設有所變化,1+e的-z次方.

這個z就相當於k*x+b*1

所以對於線性迴歸的問題,梯度下降得做出改變(改變並不大)

X數據值

1,2,3,4,5,6,7,8,9,10,100

Y數據值

0,0,0,0,1,1,1,1,1,1

代碼如下

package ojama;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.List;
import java.util.Vector;

public class GradientDescent {
	public static Double[] getTheta(List<Double[]> X, Double[] y) {
		// 初始化長度
		int m = y.length;
		// 初始化theta
		Double[] theta = new Double[X.size()];
		double a = 0.001;
		for (int i = 0; i < theta.length; i++) {
			theta[i] = 0.0;
		}
		// 迭代150000次
		for (int i = 0; i < 150000; i++) {
			// 初始化temp,做替換用
			Double[] temp = new Double[theta.length];
			for (int j = 0; j < temp.length; j++) {
				temp[j] = 0.0;
			}
			for (int j = 0; j < m; j++) {
				Double sum = 0.0;
				for (int k = 0; k < theta.length; k++) {
					// 在二元圖形中,這裏相當於k*x+b*1,三元相當於a*x+b*y+c*1,以此類推
					sum += theta[k] * X.get(k)[j];
				}
				sum = 1/(1+Math.pow(Math.E,-sum)) - y[j];
				for (int k = 0; k < theta.length; k++) {
					temp[k] += sum * X.get(k)[j];
				}
			}
			for (int j = 0; j < theta.length; j++) {
				// 一起替換 同時更新
				theta[j] -= a / m * temp[j];
			}
		}
		return theta;
	}

	public static void main(String[] args) throws IOException {
		Double[] x1 = GradientDescent.read("C:/Users/ojama/Desktop/testX.txt");
		Double[] y = GradientDescent.read("C:/Users/ojama/Desktop/testY.txt");
		int m = y.length;
		Double[] x0 = new Double[m];
		for (int i = 0; i < x0.length; i++) {
			x0[i] = 1.0;
		}
		List<Double[]> X = new Vector<Double[]>();
		X.add(x0);
		X.add(x1);
		Double[] theta = GradientDescent.getTheta(X, y);
		for (int i = 0; i < theta.length; i++) {
			System.out.println(String.format("%.2f", theta[i]));
		}
	}

	public static Double[] read(String fileName) throws IOException {
		File file = new File(fileName);
		FileReader fileReader = new FileReader(file);
		BufferedReader reader = new BufferedReader(fileReader);
		StringBuilder sb = new StringBuilder();
		String str = reader.readLine();
		while (str != null) {
			sb.append(str);
			str = reader.readLine();
		}
		reader.close();
		fileReader.close();
		String[] X0 = sb.toString().replace(" ", "").split(",");
		Double[] x0 = new Double[X0.length];
		for (int i = 0; i < x0.length; i++) {
			x0[i] = Double.parseDouble(X0[i]);
		}
		return x0;
	}
}

輸出結果


發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章