import os
from pymysql import *
import csv
import numpy as np
import matplotlib.pyplot as plt
def mysql(host="localhost", port=3306, user="root", password="root",
          database="test", charset="utf8"):
    """Fetch every row of the ``diamonds`` table from a MySQL database.

    The connection settings are keyword parameters whose defaults are the
    values this script previously hard-coded, so ``mysql()`` behaves as before.

    Args:
        host, port, user, password, database, charset: passed straight to
            ``connect`` (pymysql).

    Returns:
        The rows produced by ``SELECT * FROM diamonds``, as returned by the
        cursor's ``fetchall()`` (one tuple per row).
    """
    # NOTE(review): default credentials are hard-coded; prefer passing them in
    # (or reading them from the environment) outside of local experiments.
    cnn = connect(host=host, port=port, user=user, password=password,
                  database=database, charset=charset)
    try:
        cur = cnn.cursor()
        try:
            cur.execute("SELECT * FROM diamonds")
            # fetchall() returns the whole result set; the previous
            # fetchmany(<row count>) achieved the same thing indirectly.
            return cur.fetchall()
        finally:
            # Close the cursor even if the query fails.
            cur.close()
    finally:
        # Close the connection even if the cursor/query fails.
        cnn.close()
def creatfile(diamonds):
    """Cache query rows as CSV at ``"diamonds\\diamonds.csv"`` if not already present.

    The ``diamonds`` directory is created if needed. An existing cache file is
    left untouched, so the rows are written only on the first run.

    Args:
        diamonds: sequence of row sequences (e.g. the tuples returned by a DB
            cursor). It is passed through ``np.array`` first, which normalises
            mixed numeric columns to a common dtype before writing.
    """
    # exist_ok avoids the check-then-create race of isdir() followed by mkdir().
    os.makedirs("diamonds", exist_ok=True)
    # NOTE(review): the backslash separator names a file *inside* "diamonds"
    # only on Windows; elsewhere it is a file literally called
    # "diamonds\diamonds.csv" in the current directory. The code that reads the
    # cache back uses the same literal, so the two must be changed together.
    path = "diamonds\\diamonds.csv"
    if os.path.isfile(path):
        return
    diamonds_mat = np.array(diamonds)
    # newline='' is what the csv module requires to avoid blank lines on Windows;
    # the with-block guarantees the file is closed even if a write fails.
    with open(path, mode='w', newline='') as f:
        csv.writer(f).writerows(diamonds_mat)
def loadDataSet(filename, delim=","):
    """Load a delimited text file of numbers into a NumPy matrix.

    Args:
        filename: path of the text file to read.
        delim: field separator (default ``","``).

    Returns:
        ``(datMat, m, n)``: ``datMat`` is an ``np.matrix`` of floats with one
        row per non-blank line, and ``m``/``n`` are its row and column counts.

    Raises:
        ValueError: if a field cannot be parsed as a float (or rows are ragged).
    """
    # The with-block closes the file; the original left it open.
    with open(filename) as f:
        # Skip blank lines (e.g. stray newlines) instead of failing on float('').
        datArr = [list(map(float, line.strip().split(delim)))
                  for line in f if line.strip()]
    # np.mat was an alias of np.asmatrix and was removed in NumPy 2.0.
    datMat = np.asmatrix(datArr)
    m, n = np.shape(datMat)
    return datMat, m, n
# Pull every row of the `diamonds` table from MySQL.
diamonds = mysql()
# Cache the rows as CSV; creatfile leaves an already-existing cache file untouched.
creatfile(diamonds)
# Load the cached CSV back as a float matrix. datMat, m and n are module-level
# globals; later code in this file reads them by name (e.g. lossFunc reads datMat).
datMat, m, n = loadDataSet("diamonds\\diamonds.csv")
def regressionFunc(a, b, x):
    """Return the straight-line prediction ``a + b * x``.

    ``a`` is the intercept and ``b`` the slope. When ``a`` and ``b`` are
    scalars and ``x`` is a NumPy array/matrix, the result is element-wise.
    """
    return a + b * x
def lossFunc(a, b, m, data=None):
    """Return the squared-error cost J(a, b) = 1/(2m) * sum((a + b*x - y)^2).

    Args:
        a: intercept of the line.
        b: slope of the line.
        m: sample count used to normalise the sum.
        data: matrix with x in column 0 and y in column 1. When omitted, the
            module-level ``datMat`` is used, which keeps existing
            ``lossFunc(a, b, m)`` callers working unchanged; passing it
            explicitly matches how ``aGrad``/``bGrad`` receive their data.

    Returns:
        The scalar cost.
    """
    if data is None:
        data = datMat
    y = data[:, 1]
    y_ = regressionFunc(a, b, data[:, 0])
    J = (1 / (2 * m)) * np.sum(np.square(y_ - y))
    return J
def aGrad(a, b, datMat, m):
    """Return dJ/da, the cost gradient with respect to the intercept.

    dJ/da = (1/m) * sum(a + b*x - y), with x taken from column 0 and y from
    column 1 of ``datMat``.
    """
    residual = regressionFunc(a, b, datMat[:, 0]) - datMat[:, 1]
    return (1 / m) * np.sum(residual)
def bGrad(a, b, datMat, m):
    """Return dJ/db, the cost gradient with respect to the slope.

    dJ/db = (1/m) * sum((a + b*x - y) * x), with x taken from column 0 and y
    from column 1 of ``datMat``.
    """
    xs = datMat[:, 0]
    residual = regressionFunc(a, b, xs) - datMat[:, 1]
    weighted = np.multiply(residual, xs)
    return (1 / m) * np.sum(weighted)
def gradientFunc(a, b, datMat, alfa, iter, tol=0.0):
    """Fit y = a + b*x by batch gradient descent.

    Args:
        a: initial intercept.
        b: initial slope.
        datMat: matrix with x in column 0 and y in column 1.
        alfa: learning rate.
        iter: maximum number of iterations. The name is kept for caller
            compatibility even though it shadows the builtin.
        tol: stop early once both parameters move by at most ``tol`` in one
            step. The default 0.0 reproduces the original "no change" test.

    Returns:
        ``(a, b, tempJ, finishIter)``: the final parameters, an ``(iter, 1)``
        array holding the cost recorded before each update, and the number of
        filled entries, so ``tempJ[:finishIter]`` is the complete cost history.
    """
    # Derive the sample count from the data actually passed in, rather than
    # relying on a module-level ``m`` that may describe a different matrix.
    m = np.shape(datMat)[0]
    x = datMat[:, 0]
    y = datMat[:, 1]
    tempJ = np.zeros((iter, 1))
    finishIter = iter
    for i in range(iter):
        # Cost of the current (a, b) on *this* datMat. lossFunc(a, b, m) would
        # read the module-level datMat instead of the argument.
        tempJ[i] = (1 / (2 * m)) * np.sum(np.square(regressionFunc(a, b, x) - y))
        # Simultaneous update: both gradients are evaluated at the old (a, b).
        new_a = a - alfa * aGrad(a, b, datMat, m)
        new_b = b - alfa * bGrad(a, b, datMat, m)
        converged = abs(new_a - a) <= tol and abs(new_b - b) <= tol
        a, b = new_a, new_b
        if converged:
            # tempJ[0] .. tempJ[i] were all filled, i.e. i + 1 entries; the old
            # value i made tempJ[:finishIter] drop the last recorded cost.
            finishIter = i + 1
            break
    return a, b, tempJ, finishIter
# Fit the line by gradient descent: start from (a, b) = (-2000, 40000),
# learning rate 0.01, at most 10000 iterations.
a, b, tempJ, finishIter = gradientFunc(-2000, 40000, datMat, 0.01, 10000)

# Evaluate the fitted line on a fixed grid x in [0, 2.5] for plotting.
x = np.linspace(0, 2.5, 1000)
y = regressionFunc(a, b, x)

fig = plt.figure()

# Left panel: the data points together with the fitted line.
ax1 = fig.add_subplot(121)
ax1.scatter(datMat[:, 0].tolist(), datMat[:, 1].tolist(), color="m", s=50)
ax1.plot(x, y, color='darkorange', linewidth=3)
ax1.set_xlabel("x")
ax1.set_ylabel("y")

# Right panel: the recorded cost for the iterations that actually ran.
ax2 = fig.add_subplot(122)
ax2.plot(tempJ[:finishIter], linewidth=2.5)
ax2.set_xlabel("Iteration Times")
ax2.set_ylabel("Cost Function")

plt.show()
# 應用案例 (Application example):