Logistic迴歸梯度上升分類法

#coding=utf-8
from numpy import *
import re
def load_data(filename='test.txt'):
    """Load the training samples from a text file.

    Every non-blank line must contain at least three runs of digits, for
    example ``[2, 4, '1']``: the two feature values followed by the class
    label.  All other characters on the line are ignored.  Blank lines are
    skipped.

    Args:
        filename: Path of the sample file.  Defaults to ``'test.txt'``.

    Returns:
        A ``(data, label)`` tuple.  ``data`` is a list of ``[1.0, x1, x2]``
        float rows (the leading 1.0 is the constant column) and ``label`` is
        the list of integer class labels, in file order.

    Raises:
        ValueError: If a non-blank line holds fewer than three integers.
    """
    data = []
    label = []
    # ``with`` guarantees the file is closed even when parsing fails.
    with open(filename) as sample_file:
        for line_no, line in enumerate(sample_file, 1):
            if not line.strip():
                continue
            # findall replaces re.split(r'(\d*)', ...): since Python 3.7
            # re.split also splits on empty matches, which moved the numbers
            # away from the fixed indices 1, 3 and 5.
            numbers = re.findall(r'\d+', line)
            if len(numbers) < 3:
                raise ValueError(
                    '{0}:{1}: expected two features and a label, got {2!r}'
                    .format(filename, line_no, line))
            x1, x2, cls = numbers[:3]
            data.append([1.0, float(x1), float(x2)])
            label.append(int(cls))
    return data, label
def sigmoid(inx):
    """Return the logistic sigmoid ``1 / (1 + exp(-inx))`` element-wise.

    Args:
        inx: A number or a NumPy array/matrix.

    Returns:
        The sigmoid of ``inx``, with the same shape as ``inx``.
    """
    # 1/(1+exp(-x)) == exp(-log(1+exp(-x))).  logaddexp(0, -x) evaluates
    # log(exp(0)+exp(-x)) without forming exp(-x) on its own, so large
    # negative inputs no longer trigger an overflow RuntimeWarning.
    return exp(-logaddexp(0, -inx))
def grad_ascent(data,label,alpha=0.001,max_cycles=5000):
    """Fit logistic-regression weights by batch gradient ascent.

    Args:
        data: Sequence of m sample rows holding n values each.
        label: Sequence of m class labels, one per sample row.
        alpha: Step size applied to every update.  Defaults to 0.001.
        max_cycles: Number of full-batch updates to run.  Defaults to 5000.

    Returns:
        The fitted weights as an ndarray of shape ``(n, 1)``.
    """
    # Plain ndarrays are used instead of numpy.matrix; atleast_2d keeps the
    # old behaviour of treating a single flat row as a 1 x n sample matrix.
    samples = atleast_2d(asarray(data, dtype=float))
    # Column vector, so it lines up with the (m, 1) predictions below.
    targets = asarray(label, dtype=float).reshape(-1, 1)
    # One weight per column of ``samples``, all starting at 1.
    weights = ones((samples.shape[1], 1))
    for _ in range(max_cycles):
        # Each prediction is the sigmoid of one sample row dotted with weights.
        predicted = sigmoid(samples.dot(weights))
        # Difference between the target class and the predicted value,
        # e.g. a class-1 sample predicted at 0.72 gives 1 - 0.72.
        error = targets - predicted
        # Gradient-ascent update: w = w + alpha * X^T * (y - h).
        weights = weights + alpha * samples.T.dot(error)
    return weights
def best_fit(weights):
    """Plot the samples from load_data() together with the fitted line.

    Samples labelled 1 are drawn as red squares and all other samples as
    green markers.  The line drawn is the set of points where
    ``weights[0] + weights[1]*x + weights[2]*y == 0``, for x from 8.0 up to
    (not including) 10.0.

    Args:
        weights: Indexable holding at least three weights; entries 0, 1 and
            2 are used as described above.
    """
    import matplotlib.pyplot as plt
    samples, classes = load_data()
    points = array(samples)
    # Each pair is (x values, y values) for one class.
    class_one = ([], [])
    class_zero = ([], [])
    for row in range(shape(points)[0]):
        xs, ys = class_one if int(classes[row]) == 1 else class_zero
        xs.append(points[row, 1])
        ys.append(points[row, 2])
    figure = plt.figure()
    axes = figure.add_subplot(111)
    axes.scatter(class_one[0], class_one[1], s=30, c='red', marker='s')
    axes.scatter(class_zero[0], class_zero[1], s=30, c='green')
    # x range of the fitted line: start, stop, step.  It should cover the
    # x interval where the two classes overlap; the sample data only
    # overlaps at x = 9, hence the interval (8, 10).
    line_x = arange(8.0, 10.0, 0.1)
    line_y = (-weights[0] - weights[1] * line_x) / weights[2]
    axes.plot(line_x, line_y)
    plt.xlabel('X1')
    plt.ylabel('X2')
    plt.show()
    






樣本數據(即 load_data 讀取的 test.txt 的內容,每行格式爲 [x1, x2, '類別標籤']):

[2, 4, '1']
[5, 8, '1']
[8, 5, '1']
[9, 9, '1']
[4, 1, '1']
[0, 0, '1']
[5, 8, '1']
[9, 3, '1']
[1, 8, '1']
[7, 3, '1']
[9, 3, '1']
[3, 8, '1']
[4, 6, '1']
[9, 7, '1']
[7, 1, '1']
[5, 2, '1']
[9, 6, '1']
[6, 9, '1']
[9, 8, '1']
[7, 0, '1']
[4, 5, '1']
[9, 8, '1']
[0, 4, '1']
[4, 3, '1']
[6, 0, '1']
[9, 9, '1']
[0, 3, '1']
[9, 8, '1']
[1, 7, '1']
[5, 8, '1']
[7, 8, '1']
[1, 5, '1']
[0, 7, '1']
[1, 9, '1']
[7, 8, '1']
[2, 5, '1']
[7, 4, '1']
[2, 1, '1']
[6, 1, '1']
[0, 1, '1']
[2, 4, '1']
[6, 0, '1']
[8, 0, '1']
[4, 9, '1']
[8, 3, '1']
[9, 8, '1']
[8, 9, '1']
[5, 9, '1']
[9, 6, '1']
[4, 2, '1']
[8, 7, '1']
[1, 9, '1']
[3, 8, '1']
[0, 1, '1']
[1, 1, '1']
[0, 9, '1']
[0, 6, '1']
[1, 5, '1']
[2, 6, '1']
[9, 5, '1']
[5, 0, '1']
[2, 4, '1']
[5, 9, '1']
[9, 5, '1']
[6, 3, '1']
[9, 3, '1']
[3, 6, '1']
[8, 6, '1']
[7, 7, '1']
[0, 0, '1']
[5, 4, '1']
[2, 9, '1']
[5, 7, '1']
[3, 9, '1']
[6, 9, '1']
[8, 2, '1']
[8, 3, '1']
[8, 0, '1']
[2, 4, '1']
[9, 2, '1']
[0, 3, '1']
[6, 8, '1']
[5, 4, '1']
[5, 0, '1']
[5, 3, '1']
[7, 6, '1']
[0, 4, '1']
[3, 9, '1']
[7, 5, '1']
[8, 3, '1']
[9, 7, '1']
[8, 3, '1']
[3, 5, '1']
[2, 6, '1']
[1, 9, '1']
[6, 2, '1']
[3, 5, '1']
[9, 7, '1']
[5, 6, '1']
[7, 2, '1']
[11, 0, '0']
[13, 7, '0']
[16, 5, '0']
[11, 0, '0']
[17, 6, '0']
[9, 5, '0']
[15, 1, '0']
[13, 7, '0']
[12, 6, '0']
[9, 5, '0']
[17, 4, '0']
[10, 8, '0']
[10, 8, '0']
[9, 5, '0']
[13, 2, '0']
[13, 6, '0']
[9, 0, '0']
[11, 9, '0']
[17, 2, '0']
[9, 7, '0']
[16, 4, '0']
[12, 1, '0']
[10, 8, '0']
[10, 1, '0']
[17, 7, '0']
[12, 0, '0']
[16, 5, '0']
[18, 2, '0']
[15, 6, '0']
[9, 5, '0']
[13, 4, '0']
[13, 2, '0']
[10, 2, '0']
[17, 7, '0']
[16, 1, '0']
[15, 0, '0']
[9, 4, '0']
[16, 7, '0']
[13, 1, '0']
[17, 0, '0']
[18, 4, '0']
[12, 3, '0']
[10, 7, '0']
[14, 6, '0']
[9, 5, '0']
[11, 9, '0']
[12, 4, '0']
[17, 8, '0']
[10, 2, '0']
[12, 5, '0']
[13, 0, '0']
[12, 2, '0']
[11, 1, '0']
[14, 1, '0']
[17, 0, '0']
[18, 3, '0']
[10, 5, '0']
[18, 2, '0']
[12, 4, '0']
[15, 8, '0']
[17, 9, '0']
[18, 5, '0']
[14, 9, '0']
[16, 9, '0']
[18, 5, '0']
[9, 1, '0']
[14, 4, '0']
[13, 2, '0']
[12, 9, '0']
[16, 8, '0']
[15, 4, '0']
[12, 0, '0']
[16, 9, '0']
[14, 3, '0']
[12, 9, '0']
[17, 5, '0']
[11, 4, '0']
[13, 6, '0']
[16, 3, '0']
[16, 2, '0']
[11, 3, '0']
[11, 1, '0']
[17, 9, '0']
[18, 2, '0']
[11, 8, '0']
[14, 3, '0']
[11, 0, '0']
[18, 6, '0']
[12, 6, '0']
[10, 0, '0']
[14, 0, '0']
[16, 5, '0']
[12, 7, '0']
[15, 0, '0']
[15, 1, '0']
[18, 9, '0']
[9, 0, '0']
[18, 0, '0']
[18, 6, '0']
[9, 3, '0']

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章