小批量梯度下降算法 python

# -*- coding: utf-8 -*-
"""
Created on Tue Mar 13 20:49:03 2018

@author: 
"""

import numpy as np
from scipy import stats
import matplotlib.pyplot as plt

##產生訓練數據,生成模型爲2*x+5+random.randint(50)
x=np.arange(0.,10.,0.2)
m=len(x)
print(m)
x0=np.full(m,1.0)    
input_data=np.vstack([x0,x]).T  
target_data=2*input_data[:,1]+5*input_data[:,0]+np.random.randn(m)

loop_max=100000             #設置最大訓練次數,防止程序死循環
epsilon=1e-3               # 設置訓練模型容許誤差

np.random.seed(0)           #設置隨機產生種子,讓每次生成隨機數一致
theta=np.random.randn(2)     # 初始化訓練模型的權重

alpha=0.001                  #訓練速度(太大容易導致欠擬合,太小容易導致模型不收斂)
diff=0. 
error = np.zeros(2)           #初始化模型誤差
count=0                      #統計模型循環次數
finish=0                     #模型訓練截止標誌
minibatch_size=5            #小批量採樣間隔(也可以每次只採樣這幾個數據)
while count<loop_max:
    count+=1
    
    for i in range(1,m,minibatch_size):
        sum_m=np.zeros(2)
        k=np.random.randint(0,49)         # 隨機選取數據更新模型權重
        dif=(np.dot(theta,input_data[k])-target_data[k])*input_data[k]
        sum_m=sum_m+dif
        theta=theta-alpha*(1.0/minibatch_size)*sum_m
  ##跳出循環條件      
    if np.linalg.norm(theta-error)<epsilon:
        finish=1
        break
    else:
        error=theta
    print('loopcount=%d' %count, '\tw:', theta)
print('loop count=%d' %count,'\tw:',theta)
#根據python scipy 庫中stats線性擬合函數來驗證模型的正確性

slope, intercept, r_value, p_value,slope_std_error = stats.linregress(x, target_data)  
print ('intercept = %s slope = %s'% (intercept, slope) )
 
 #畫出訓練數據與擬合的模型圖
plt.plot(x, target_data, 'g*')  
plt.plot(x, theta[1]* x +theta[0],'r')  
plt.show()  
            












發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章