# -*- coding: utf-8 -*-
"""
Created on Tue Mar 13 20:49:03 2018
@author:
"""
import numpy as np
from scipy import stats
import matplotlib.pyplot as plt
##產生訓練數據,生成模型爲2*x+5+random.randint(50)
x=np.arange(0.,10.,0.2)
m=len(x)
print(m)
x0=np.full(m,1.0)
input_data=np.vstack([x0,x]).T
target_data=2*input_data[:,1]+5*input_data[:,0]+np.random.randn(m)
loop_max=100000 #設置最大訓練次數,防止程序死循環
epsilon=1e-3 # 設置訓練模型容許誤差
np.random.seed(0) #設置隨機產生種子,讓每次生成隨機數一致
theta=np.random.randn(2) # 初始化訓練模型的權重
alpha=0.001 #訓練速度(太大容易導致欠擬合,太小容易導致模型不收斂)
diff=0.
error = np.zeros(2) #初始化模型誤差
count=0 #統計模型循環次數
finish=0 #模型訓練截止標誌
minibatch_size=5 #小批量採樣間隔(也可以每次只採樣這幾個數據)
while count<loop_max:
count+=1
for i in range(1,m,minibatch_size):
sum_m=np.zeros(2)
k=np.random.randint(0,49) # 隨機選取數據更新模型權重
dif=(np.dot(theta,input_data[k])-target_data[k])*input_data[k]
sum_m=sum_m+dif
theta=theta-alpha*(1.0/minibatch_size)*sum_m
##跳出循環條件
if np.linalg.norm(theta-error)<epsilon:
finish=1
break
else:
error=theta
print('loopcount=%d' %count, '\tw:', theta)
print('loop count=%d' %count,'\tw:',theta)
#根據python scipy 庫中stats線性擬合函數來驗證模型的正確性
slope, intercept, r_value, p_value,slope_std_error = stats.linregress(x, target_data)
print ('intercept = %s slope = %s'% (intercept, slope) )
#畫出訓練數據與擬合的模型圖
plt.plot(x, target_data, 'g*')
plt.plot(x, theta[1]* x +theta[0],'r')
plt.show()