統計學習方法之感知機 python代碼實現

感知機是二分類線性模型,其輸入爲實例的特徵向量,輸出爲實例的類別,取+1和-1。

根據《統計學習方法》第2章,用python實現感知機。

import numpy as np
import matplotlib.pyplot as plt


class Perceptron:

    def __init__(self, dim):
        """ 初始化權重 w b 以及x特徵的維度dim"""
        self.w = np.random.random((dim,))
        self.b = np.random.rand()
        self.dim = dim

    def f(self, x):
        """ w * x + b"""
        return np.dot(self.w, x.T) + self.b

    def gw(self, x, y):
        """ w梯度 """
        return x * y

    def gb(self,  y):
        """ b梯度"""
        return y

    def train_2(self, xs, ys, alpha=1, max_step=1000):
        """
        學習算法的對偶形式
        """

        n = len(xs)
        matrix = np.zeros((n,n),dtype=int)
        for i in range(n):
            for j in range(n):
                matrix[i][j] = np.dot(xs[i],xs[j])
        print("matrix:", matrix)
        a, b = [0]*n, 0
        step = 1
        error = float("inf")
        while error and step < max_step:
            print("step: {}, error: {}".format(step, error))
            error = 0
            for i,(x, y) in enumerate(zip(xs, ys)):
                sum_ = b
                for j in range(n):
                    sum_ += a[j]*ys[j]*matrix[j][i]
                if y*sum_ <= 0:
                    a[i] = a[i] + alpha
                    b += alpha * y
                    error += 1
            step += 1
        self.w = sum(a[i]*xs[i]*ys[i] for i in range(n))
        self.b = b
        if step >= max_step:
            print("max_step limit")
        else:
            print("successfully learn")
        print("w: {} b: {}".format(self.w, self.b))

    def train(self, xs, ys, alpha=0.1, max_step=1000):
        """
        學習算法
        :param xs: 輸入樣本特徵
        :param ys: 樣本分類
        :param alpha: 學習率
        :param max_step: 最大迭代倫次
        :return:
        """
        # 輸入數據維度要等於指定維度(畫圖只能是二維)
        if xs.shape[1] != self.dim:
            raise ValueError("x sample must {} dim".format(self.dim))

        # 初始化圖
        fig = self.figure_init(xs, ys)

        error = float("inf")
        step = 1
        while error and step < max_step:
            print("step: {}, error: {}".format(step, error))
            error = 0
            for x, y in zip(xs, ys):
                if self.f(x) * y <= 0:
                    self.w += alpha * self.gw(x, y)
                    self.b += alpha * self.gb(y)
                    self.figure_update(fig)  # 更新圖
                    error += 1
            step += 1
        if step >= max_step:
            print("max_step limit")
        else:
            print("successfully learn")
        print("w: {} b: {}".format(self.w, self.b))

        # 關閉交互模式
        plt.ioff()

        # 圖形顯示
        plt.show()

    def figure_init(self, xs, ys):

        fig = plt.figure(figsize=(8, 6), dpi=80)
        ax = fig.add_subplot(1, 1, 1)
        # 設定標題等
        plt.title("Perceptron")
        plt.grid(True)

        # 設置X軸
        plt.xlabel("x(1)")
        plt.xlim(-10, 10)

        # 設置Y軸
        plt.ylabel("x(2)")
        plt.ylim(-10, 10)

        # 畫點
        for x_, y_ in zip(xs, ys):
            if y_ == 1:
                ax.plot(x_[0], x_[1], "ro")
            else:
                ax.plot(x_[0], x_[1], "go")

        # 生成超平面 w1 * x1 + w2 * x2 + b = 0,隨機取兩個x1,計算x2連起來
        x1 = np.linspace(-10, 10, 2, endpoint=True)
        x2 = (- self.b - self.w[0] * x1) / self.w[1]

        # 畫直線
        lines = ax.plot(x1, x2, "b-", linewidth=2.0, label="hyperplane")

        # 設置圖例位置,loc可以爲[upper, lower, left, right, center]
        ax.legend(loc="upper left", shadow=True)

        # 暫停
        plt.pause(0.5)

        ax.lines.remove(lines[0])
        # 打開交互模式
        plt.ion()

        return ax

    def figure_update(self, ax):
        """
        畫圖程序,每更新一次權重,調用一次
        """
        # 生成超平面 w1 * x1 + w2 * x2 + b = 0,隨機取兩個x1,計算x2連起來
        x1 = np.linspace(-10, 10, 2, endpoint=True)
        x2 = (- self.b - self.w[0]*x1)/self.w[1]

        # 更新直線
        lines = ax.plot(x1, x2, "b-", linewidth=2.0, label="hyperplane")
        # 暫停
        plt.pause(0.5)

        # 刪掉直線
        ax.lines.remove(lines[0])


p = Perceptron(2)
x_data = np.array([[3, 3], [4, 3], [1, 1]])
y_data = np.array([1, 1, -1])
p.train(x_data, y_data)
# p.train_2(x_data, y_data)
其中,train函數是學習算法的原始形式,train_2是學習算法的對偶形式。其中error代表誤分類的個數,簡單的邏輯就是,當誤分類的個數爲0或者超出最大學習輪次則停止學習。
另外,figrue_init 和 figure_update函數則是在訓練過程的圖形展示。
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章