邏輯迴歸的相關問題及java實現

本講主要說下邏輯迴歸的相關問題和具體的實現方法

1. 什麼是邏輯迴歸

邏輯迴歸是線性迴歸的一種,那麼什麼是迴歸,什麼是線性迴歸

迴歸指的是公式已知,對公式中的未知參數進行估計,注意公式必須是已知的,否則是沒有辦法進行迴歸的

線性迴歸指的是迴歸中的公式是一次的,例如z=ax+by

邏輯迴歸其實就是在線性迴歸的基礎上套了一個sigmoid函數,具體的樣子如下


2. 正則化項

引入正則化項的目的是防止模型過擬合,函數對樣本的擬合有三種結果

欠擬合:直觀的理解就是在訓練集上的誤差比較大,擬合出來的函數應該是曲線,結果擬合成了一條直線

過擬合:在訓練集上的誤差很小甚至爲0,追求經驗風險最小化,模型擬合的很複雜,往往在未知的樣本集上表現的不夠好

合適的擬合:在訓練集合測試集上都表現的比較好,追求經驗風險和結構風險的均衡

解決過擬合的問題一般有兩種方法,一是減少特徵的維度,二是進行正則化。對減少特徵的維度我的理解是造成過擬合的原因是特徵太多樣本太少,所以進行特徵選擇以減少特徵會得到比較好的擬合效果,下面詳細說一下正則化。

先看一下正則化的樣子


其實就是在損失函數里加入一個正則化項,正則化項就是權重的L1或者L2範數乘以一個lamda,用來控制損失函數和正則化項的比重,直觀的理解,首先防止過擬合的目的就是防止最後訓練出來的模型過分的依賴某一個特徵,當最小化損失函數的時候,某一維度很大,擬合出來的函數值與真實的值之間的差距很小,通過正則化可以使整體的cost變大,從而避免了過分依賴某一維度的結果。當然加正則化的前提是特徵值要進行歸一化,例如有的特徵的範圍是200-500,有個特徵的範圍是0-1,這個時候就要進行歸一化,例如都化爲0-1之間。

3. 最小二乘法和最大似然法

最小二乘法,感覺名字起的不好,不能一目瞭然,有點拗口,其實就是最小平方和的意思麼,那麼爲什麼用最小二乘法呢,我們知道,我們的目的就是較少預測值和真實值之間的差值,那麼直接把差值直接加起來作爲誤差不就好了,當然不行,因爲誤差有正有負,有些誤差會抵消,那麼絕對值的和呢,聽起來也比較合理,理論上應該也可以,不過最小二乘法有個比較合理解釋,有樣本點D,然後很多候選的曲線h來分開這些點,那麼選擇哪條直線呢,我們選的應該是後驗概率最大的那條線,也就是P(h|D)最大的那條線。由貝葉斯知道p(h|D)正比於p(h)*p(D|h),先驗概率p(h)認爲是均等的,所以只要最大化p(D|h)即可,因爲樣本點D是獨立的,所以p(D|h)=p(d1|h)*p(d2|h)*......*p(dn|h )。我們認爲這些點是含有噪音的,是因爲噪音讓他偏離了一條完美的曲線,一種很合理的假設就是偏離遠大的概率越小,那麼這個偏離的概率可以用正態分佈來描述,形式化的表達爲p(dn|h)=exp(-delta^2),所以p(D|h)=exp(-(delta1^2+delta2^2......+deltan^2)),我們的目的是最大化這個概率,等價於最小化裏面的平方和,min(delta1^2+delta2^2......+deltan^2),是不是很熟悉啊

這個時候,我們看一下,最小二乘法適合做邏輯迴歸的誤差函數麼,答案是不適合,因爲最小二乘法的誤差我們假設的事符合正態分佈,而邏輯迴歸的誤差符合的是二項分佈,所以不能用最小二乘法來作爲損失函數,那麼可以用最大似然估計來做

4. java實現梯度下降法

實驗:

樣本:

-0.017612	14.053064	0
-1.395634	4.662541	1
-0.752157	6.538620	0
-1.322371	7.152853	0
0.423363	11.054677	0
0.406704	7.067335	1
0.667394	12.741452	0
-2.460150	6.866805	1
0.569411	9.548755	0
-0.026632	10.427743	0
0.850433	6.920334	1
1.347183	13.175500	0
1.176813	3.167020	1
-1.781871	9.097953	0
-0.566606	5.749003	1
0.931635	1.589505	1
-0.024205	6.151823	1
-0.036453	2.690988	1
-0.196949	0.444165	1
1.014459	5.754399	1
1.985298	3.230619	1
-1.693453	-0.557540	1
-0.576525	11.778922	0
-0.346811	-1.678730	1
-2.124484	2.672471	1
1.217916	9.597015	0
-0.733928	9.098687	0
-3.642001	-1.618087	1
0.315985	3.523953	1
1.416614	9.619232	0
-0.386323	3.989286	1
0.556921	8.294984	1
1.224863	11.587360	0
-1.347803	-2.406051	1
1.196604	4.951851	1
0.275221	9.543647	0
0.470575	9.332488	0
-1.889567	9.542662	0
-1.527893	12.150579	0
-1.185247	11.309318	0
-0.445678	3.297303	1
1.042222	6.105155	1
-0.618787	10.320986	0
1.152083	0.548467	1
0.828534	2.676045	1
-1.237728	10.549033	0
-0.683565	-2.166125	1
0.229456	5.921938	1
-0.959885	11.555336	0
0.492911	10.993324	0
0.184992	8.721488	0
-0.355715	10.325976	0
-0.397822	8.058397	0
0.824839	13.730343	0
1.507278	5.027866	1
0.099671	6.835839	1
-0.344008	10.717485	0
1.785928	7.718645	1
-0.918801	11.560217	0
-0.364009	4.747300	1
-0.841722	4.119083	1
0.490426	1.960539	1
-0.007194	9.075792	0
0.356107	12.447863	0
0.342578	12.281162	0
-0.810823	-1.466018	1
2.530777	6.476801	1
1.296683	11.607559	0
0.475487	12.040035	0
-0.783277	11.009725	0
0.074798	11.023650	0
-1.337472	0.468339	1
-0.102781	13.763651	0
-0.147324	2.874846	1
0.518389	9.887035	0
1.015399	7.571882	0
-1.658086	-0.027255	1
1.319944	2.171228	1
2.056216	5.019981	1
-0.851633	4.375691	1
-1.510047	6.061992	0
-1.076637	-3.181888	1
1.821096	10.283990	0
3.010150	8.401766	1
-1.099458	1.688274	1
-0.834872	-1.733869	1
-0.846637	3.849075	1
1.400102	12.628781	0
1.752842	5.468166	1
0.078557	0.059736	1
0.089392	-0.715300	1
1.825662	12.693808	0
0.197445	9.744638	0
0.126117	0.922311	1
-0.679797	1.220530	1
0.677983	2.556666	1
0.761349	10.693862	0
-2.168791	0.143632	1
1.388610	9.341997	0
0.317029	14.739025	0

主要代碼

public class LogRegression {

	public static void main(String[] args) {
		
		LogRegression lr = new LogRegression();
		Instances instances = new Instances();
		lr.train(instances, 0.01f, 200, (short)1);
	}
	
	public void train(Instances instances, float step, int maxIt, short algorithm) {
		
		float[][] datas = instances.datas;
		float[] labels = instances.labels;
		int size = datas.length;
		int dim = datas[0].length;
		float[] w = new float[dim];
		
		for(int i = 0; i < dim; i++) {
			w[i] = 1;
		}
		
		switch(algorithm){
		//批量梯度下降
		case 1: 
			for(int i = 0; i < maxIt; i++) {
				//求輸出
				float[] out = new float[size];
				for(int s = 0; s < size; s++) {
					float lire = innerProduct(w, datas[s]);
					out[s] = sigmoid(lire);
				}
				for(int d = 0; d < dim; d++) {
					float sum = 0;
					for(int s = 0; s < size; s++) {
						sum  += (labels[s] - out[s]) * datas[s][d];
					}
					w[d] = w[d] + step * sum;
				}
				System.out.println(Arrays.toString(w));
			}
			break;
		//隨機梯度下降
		case 2: 
			for(int i = 0; i < maxIt; i++) {
				for(int s = 0; s < size; s++) {
					float lire = innerProduct(w, datas[s]);
					float out = sigmoid(lire);
					float error = labels[s] - out;
					for(int d = 0; d < dim; d++) {
						w[d] += step * error * datas[s][d];
					}
				}
				System.out.println(Arrays.toString(w));
			}
			break;
		}
	}
	
	private float innerProduct(float[] w, float[] x) {
		
		float sum = 0;
		for(int i = 0; i < w.length; i++) {
			sum += w[i] * x[i];
		}
		
		return sum;
	}
	
	private float sigmoid(float src) {
		return (float) (1.0 / (1 + Math.exp(-src)));
	}
}

效果



發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章