Python:批量梯度下降實現一元線性迴歸

# -*- coding: utf-8 -*-
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt


class LinearRegression(object):
    """Univariate linear model ``y = w * x + b`` trained by batch gradient descent.

    Every update uses the whole training set. The loss measured after each
    update is appended to ``loss_arr``.
    """

    def __init__(self, learning_rate=0.01, max_iter=100, seed=None):
        """Create an untrained model with randomly initialised parameters.

        Args:
            learning_rate: Step size applied to every gradient update.
            max_iter: Number of full-batch updates performed by one ``fit`` call.
            seed: Seed for the parameter initialisation; ``None`` gives a
                non-reproducible start.
        """
        # A private generator produces the same draws that seeding the global
        # NumPy RNG would, but leaves random state used elsewhere untouched.
        rng = np.random.RandomState(seed)
        self.lr = learning_rate
        self.max_iter = max_iter
        # Slope and intercept both start near 1 (normal, mean 1, std 0.1).
        self.w = rng.normal(1, 0.1)
        self.b = rng.normal(1, 0.1)
        # Loss after each update; never cleared, so repeated fits keep appending.
        self.loss_arr = []

    def fit(self, x, y):
        """Run ``max_iter`` gradient-descent updates on the samples ``(x, y)``.

        Training continues from the current ``w`` and ``b``.

        Args:
            x: Array-like of input values.
            y: Array-like of target values with the same shape as ``x``.

        Raises:
            ValueError: If ``x`` and ``y`` differ in shape or are empty.
        """
        x = np.asarray(x, dtype=float)
        y = np.asarray(y, dtype=float)
        # Mismatched shapes could otherwise broadcast into a silently wrong fit.
        if x.shape != y.shape:
            raise ValueError(
                "x and y must have the same shape, got %s and %s"
                % (x.shape, y.shape)
            )
        if x.size == 0:
            raise ValueError("x and y must not be empty")
        self.x = x
        self.y = y
        for _ in range(self.max_iter):
            self._train_step()
            self.loss_arr.append(self.loss())

    def model(self, x, w, b):
        """Evaluate the line ``x * w + b``."""
        return x * w + b

    def predict(self, x=None):
        """Predict targets for ``x`` with the current parameters.

        Args:
            x: Inputs to evaluate; ``None`` means the stored training inputs.

        Returns:
            ``x * w + b`` evaluated element-wise.
        """
        if x is None:
            x = self.x
        return self.model(x, self.w, self.b)

    def loss(self, y_true=None, y_pred=None):
        """Return the mean squared error between ``y_true`` and ``y_pred``.

        Args:
            y_true: Reference targets; ``None`` means the stored training targets.
            y_pred: Predictions; ``None`` means predictions on the stored
                training inputs.
        """
        # Each argument falls back on its own, so a single supplied value is
        # honoured instead of being silently replaced by the training data.
        if y_true is None:
            y_true = self.y
        if y_pred is None:
            y_pred = self.predict()
        return np.mean((y_true - y_pred) ** 2)

    def _calc_gradient(self):
        """Return ``(d_w, d_b)`` for the stored training data.

        The constant factor 2 of the squared-error derivative is omitted, so
        it is effectively folded into the learning rate.
        """
        residual = self.predict() - self.y
        d_w = np.mean(residual * self.x)
        d_b = np.mean(residual)
        return d_w, d_b

    def _train_step(self):
        """Apply one gradient-descent update and return the new ``(w, b)``."""
        d_w, d_b = self._calc_gradient()
        self.w = self.w - self.lr * d_w
        self.b = self.b - self.lr * d_b
        return self.w, self.b

def generate_data(data_size=100, seed=272):
    """Build a noisy synthetic sample of the line ``y = 20 * X + 10``.

    ``X`` is drawn uniformly from [1.0, 10.0) and ``y`` adds zero-mean
    Gaussian noise with standard deviation 10 to the exact line.

    Args:
        data_size: Number of samples (rows) to generate.
        seed: Seed for the random draws; the default reproduces the
            original fixed data set.

    Returns:
        A ``pandas.DataFrame`` with the columns ``"X"`` and ``"y"``.
    """
    # A private generator gives the same numbers as seeding the global NumPy
    # RNG, without resetting random state that other code may depend on.
    rng = np.random.RandomState(seed)
    X = rng.uniform(low=1.0, high=10.0, size=data_size)
    noise = rng.normal(loc=0.0, scale=10.0, size=data_size)
    y = X * 20 + 10 + noise
    return pd.DataFrame({"X": X, "y": y})

if __name__ == '__main__':
    def show_data(x, y, w=None, b=None):
        """Scatter the samples and, when both w and b are given, draw the line w * x + b."""
        plt.scatter(x, y, marker='.')
        line_missing = w is None or b is None
        if not line_missing:
            plt.plot(x, w * x + b, c='red')
        plt.show()

    # First column holds the inputs, second column the targets.
    data = np.array(generate_data())
    x, y = data[:, 0], data[:, 1]

    regr = LinearRegression(learning_rate=0.01, max_iter=10, seed=111)
    regr.fit(x, y)

    # Figure 1: the samples together with the fitted line.
    show_data(x, y, regr.w, regr.b)

    # Figure 2: the recorded loss against the update index.
    steps = np.arange(len(regr.loss_arr))
    plt.scatter(steps, regr.loss_arr, marker='o', c='green')
    plt.show()

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章