『自己的工作3』梯度下降實現SVM多分類+最詳細的數學推導+Python實戰(鳶尾花數據集)!

梯度下降實現SVM多分類+最詳細的數學推導+Python實戰(鳶尾花數據集)!

一. SVM梯度公式詳細推導

支持向量機(Support Vector Machine, SVM)的基本模型是在特徵空間上找到最佳的分離超平面使得訓練集上正負樣本間隔最大。SVM的目標是尋找一個最優化超平面在空間中分割兩類數據,這個最優化超平面需要滿足的條件是:離其最近的點到其的距離最大化,這些點被稱爲支持向量。SVM是用來解決二分類問題的有監督學習算法,同時它可以通過one-vs-all策略應用到多分類問題中。本文主要介紹如何使用梯度下降法對SVM多分類問題進行優化。

1.1. SVM多分類模型

假設數據集 XRk×n\mathbf{X} \in \mathrm{R}^{k \times n}nn 爲訓練樣本的個數,kk 爲每個樣本的維度。另外注意:下面使用的是L2-SVM!
L(wc,bc)=cC[12wcTwc+λiNmax{0,1yic(wcTxi+bc)}2]=12cCwcTwc+λcCiNmax{0,1yic(wcTxi+bc)}2(1) \begin{aligned} \mathcal{L}(\mathbf{\boldsymbol w_{c}}, b_{c}) &=\sum_{c}^{C}\left[\frac{1}{2}\boldsymbol w_{c}^{T} \boldsymbol w_{c}+\lambda \sum_{i}^{N} \max \left\{0,1-y_{i}^{c}\left(\boldsymbol w_{c}^{T} \boldsymbol x_{i}+b_{c}\right)\right\}^{2}\right] \\ &=\frac{1}{2}\sum_{c}^{C} \boldsymbol w_{c}^{T} \boldsymbol w_{c}+\lambda \sum_{c}^{C} \sum_{i}^{N} \max \left\{0,1-y_{i}^{c}\left(\boldsymbol w_{c}^{T} \boldsymbol x_{i}+b_{c}\right)\right\}^2 \tag{1} \end{aligned}

1.2. SVM多分類梯度公式推導

首先,當 1yi(wTxi+b)<01-y_{i}\left(\boldsymbol w^{T} \boldsymbol x_{i}+b\right)<0 的樣本,此時相當於分類正確的情況,不需要加上 Hinge-Loss,因此我們有如下:
L(wc,bc)=12cCwcTwc=tr(WTW)(2) \begin{aligned} \mathcal{L}(\mathbf{\boldsymbol w_{c}}, b_{c}) = \frac{1}{2}\sum_{c}^{C} \boldsymbol w_{c}^{T} \boldsymbol w_{c}=\operatorname{tr}\left(\mathbf{W}^{\mathrm{T}} \mathbf{W}\right) \tag{2} \end{aligned} L(wc,bc)wc=wc(3) \begin{aligned} \frac{\partial \mathcal{L}(\mathbf{\boldsymbol w_{c}}, b_{c})}{\partial \boldsymbol w_{c}} =\boldsymbol w_{c} \tag{3} \end{aligned}
其次,當1yi(wTxi+b)>01-y_{i}\left(\boldsymbol w^{T} \boldsymbol x_{i}+b\right)>0 的樣本,此時相當於分類不正確的情況,需要加上 Hinge-Loss,因此我們有如下:
L(wc,bc)=cC[12wcTwc+λiNmax{0,1yic(wcTxi+bc)}2]=12cCwcTwc+λcCiN1yic(wcTxi+bc)2=12cCwcTwc+λcCiN[1+(wcTxi+bc)22yic(wcTxi+bc)]=12cwcTwc+λc=1Ci=1N[1+wcTxixiTwc+bc2+2wcTxibc2yicwcTxi2yicbc]=12cwcTwc+λc=1c[n(1+bc2)+wcT(i=1NxixiT)wc+2bcwcT(i=1Nxi)2wcT(i=1Nxiyic)2(i=1Nyic)bc](4) \begin{aligned} \mathcal{L}(\mathbf{\boldsymbol w_{c}}, b_{c}) &=\sum_{c}^{C}\left[\frac{1}{2}\boldsymbol w_{c}^{T} \boldsymbol w_{c}+\lambda \sum_{i}^{N} \max \left\{0,1-y_{i}^{c}\left(\boldsymbol w_{c}^{T} \boldsymbol x_{i}+b_{c}\right)\right\}^{2}\right] \\ &=\frac{1}{2}\sum_{c}^{C} \boldsymbol w_{c}^{T} \boldsymbol w_{c}+\lambda \sum_{c}^{C} \sum_{i}^{N}\left\|1-y_{i}^{c}\left(\boldsymbol w_{c}^{T} x_{i}+b_{c}\right)\right\|^{2} \\ &=\frac{1}{2} \sum_{c}^{C} \boldsymbol w_{c}^{T} \boldsymbol w_{c}+\lambda \sum_{c}^{C} \sum_{i}^{N}\left[1+\left(\boldsymbol w_{c}^{T} \boldsymbol x_{i}+b_{c}\right)^{2}-2 y_{i}^{c}\left(\boldsymbol w_{c}^{T} \boldsymbol x_{i}+b_{c}\right)\right] \\& =\frac{1}{2}\sum_{c} \boldsymbol w_{c}^{T} \boldsymbol w_{c}+\lambda \sum_{c=1}^{C} \sum_{i=1}^{N}\left[1+\boldsymbol w_{c}^{T} \boldsymbol x_{i} \boldsymbol x_{i}^{T} \boldsymbol w_{c}+b_{c}^{2}+2 \boldsymbol w_{c}^{T} \boldsymbol x_{i} b_{c}-2 y_{i}^{c} \boldsymbol w_{c}^{T} \boldsymbol x_{i}-2 y_{i}^{c} b_{c}\right] \\&=\frac{1}{2}\sum_{c} \boldsymbol w_{c}^{T} \boldsymbol w_{c}+\lambda \sum_{c=1}^{c}\left[n\left(1+b_{c}^{2}\right)+\boldsymbol w_{c}^{T}\left(\sum_{i=1}^{N}\boldsymbol x_{i}\boldsymbol x_{i}^{T}\right) \boldsymbol w_{c}+2 b_{c} \boldsymbol w_{c}^{T}\left(\sum_{i=1}^{N} \boldsymbol x_{i}\right)-2 \boldsymbol w_{c}^{T}\left(\sum_{i=1}^{N} \boldsymbol x_{i} y_{i}^{c}\right)-2\left(\sum_{i=1}^{N} y_{i}^{c}\right) b_{c}\right] \tag{4} \end{aligned}

然後,整理上面可以得到:
L(wc,bc)=12iwcTwc+λi[n(1+bc2)+wcTXXTwc+2bcwcTXE2wcTXyc2bcETyc](5) \begin{aligned} \mathcal{L}(\mathbf{\boldsymbol w_{c}}, b_{c}) &= \frac{1}{2}\sum_{i}\boldsymbol w_{c}^{T} \boldsymbol w_{c}+\lambda \sum_{i}\left[n\left(1+b_{c}^{2}\right)+\boldsymbol w_{c}^{T} \mathbf{X} \mathbf{X}^{\mathrm{T}} \boldsymbol w_{c}+2 b_{c} \boldsymbol w_{c}^{T} \mathbf{X} \mathbf{E}-2 \boldsymbol w_{c}^{T} \mathbf{X} \mathbf{y}_{c}-2 b_{c} \mathbf{E}^{T} \mathbf{y}_{c}\right] \tag{5} \end{aligned}

此外,上述公式 (5)(5) 還可以繼續化簡,這裏只是提供一個思路!
L(W,b)=12tr(WTW)+λ[n(c+bTb)+tr(WTXXTW)+2bTWTXE2tr(WTXYT)2bTYE](6) \begin{aligned} \mathcal{L}(\mathbf{\mathbf{W}}, \mathbf{b}) =\frac{1}{2}\operatorname{tr}\left(\mathbf{W}^{\mathrm{T}} \mathbf{W}\right)+\lambda\left[n\left(c+\mathbf{b}^{\mathrm{T}} \mathbf{b}\right)+\operatorname{tr}\left(\mathbf{W}^{\mathrm{T}} \mathbf{X} \mathbf{X}^{\mathrm{T}} \mathbf{W}\right)+2 \mathbf{b}^{\mathrm{T}} \mathbf{W}^{\mathrm{T}} \mathbf{X} \mathbf{E}-2 \operatorname{tr}\left(\mathbf{W}^{\mathrm{T}} \mathbf{X} \mathbf{Y}^{\mathrm{T}}\right)-2 \mathbf{b}^{\mathrm{T}} \mathbf{Y} \mathbf{E}\right] \tag{6} \end{aligned}
目標函數 (5)(5) 分別對 wc\boldsymbol w_{c}bcb_c 求偏導數,可以得到如下:
L(wc,bc)wc=wc+2XXTwc+2XEbc2Xyc(7) \frac{\partial \mathcal{L}(\mathbf{\boldsymbol w_{c}}, b_{c})}{\partial \boldsymbol w_{c}} = \boldsymbol w_{c}+2 \mathbf{X} \mathbf{X}^{\mathrm{T}} \boldsymbol w_{c}+2 \mathbf{X} \mathbf{E} b_{c}-2 \mathbf{X}{\mathbf y}_{c}\tag{7} L(wc,bc)bc=2nbc2ycTE(8) \frac{\partial \mathcal{L}(\mathbf{\boldsymbol w_{c}}, b_{c})}{\partial b_{c}} = 2 n b_{c}-2{\mathbf y}_{c}^{T} \mathbf{E}\tag{8}

Xk×n=[x1(1)x1(2).x1(n)............xk(1)xk(2).xk(n)]k×n{\mathbf{X} }_{k \times n}=\left[\begin{array}{cccc}{ x_{1}^{(1)}} & { x_{1}^{(2)}} & {.} & { x_{1}^{(n)}} \\ {.} & {.} & {.} & {.} \\ {.} & {.} & {.} & {.} \\ {.} & {.} & {.} & {.} \\ { x_{k}^{(1)}} & { x_{k}^{(2)}} & {.} & { x_{k}^{(n)}}\end{array}\right]_{k \times n} YC×n=[y1(1)y2(1).yn(1)y1(C)y2(C)yn(C)]C×n{\mathbf{Y}}_{{C \times n}}=\left[\begin{array}{cccc}{y_{1}^{(1)}} & {y_{2}^{(1)}} & {.} & {y_{n}^{(1)}} \\ {\cdot} & {\cdot} & {\cdot} & {\cdot} \\ {\cdot} & {\cdot} & {\cdot} & {\cdot} \\ {y_{1}^{(C)}} & {y_{2}^{(C)}} & {\cdot} & {y_{\mathrm{n}}^{(C)}}\end{array}\right]_{{C \times n}} E=[11]n×1\mathbf{E}=\left[\begin{array}{l}{1} \\ {\cdot} \\ {\cdot} \\ {1}\end{array}\right]_{n \times 1}
Wk×C=[w1(1)w1(2).w1(C)........wk(1)wk(2).wk(C)]k×C\mathbf{W}_{k \times C}=\left[\begin{array}{lll}{w_{1}^{(1)}} & {w_{1}^{(2)}} & {.} & {w_{1}^{(C)}} \\ {.} & {.} & {.} & {.} \\ {.} & {.} & {.} & {.} \\ {w_{k}^{(1)}} & {w_{k}^{(2)}} & {.} & {w_{k}^{(C)}}\end{array}\right]_{k \times C} b=[b1..bC]C×1\mathbf{b}=\left[\begin{array}{l}{b_{1}} \\ {.} \\ {.} \\ {b_{C}}\end{array}\right]_{C \times 1} yc=[y1(c)yn(c)]n×1\boldsymbol{y}_{\mathrm{c}}=\left[\begin{array}{c}{y_{1}^{(c)}} \\ {\cdot} \\ {\cdot} \\ {y_{n}^{(c)}}\end{array}\right]_{n \times 1}

1.3. SVM多分類算法優化過程

整體迭代過程非常容易理解,主要分爲以下兩個模塊(具體的過程看2.2章節的代碼實現):

  • 如果進來的樣本不滿足條件1yi(wTxi+b)>0,1-y_{i}\left(\boldsymbol w^{T} \boldsymbol x_{i}+b\right)>0, 那麼將盡可能能往滿足條件的方向優化(此時使用 Hinge-Loss在SVM的原問題空間對問題進行優化)。
  • 如果進來的樣本符條件1yi(wTxi+b)<0,1-y_{i}\left(\boldsymbol w^{T} \boldsymbol x_{i}+b\right)<0, 那麼參數保持不變。

二. Python代碼實戰

2.1. 鳶尾花數據集介紹

Iris 鳶尾花數據集包含3 類共 150 條記錄,每類各 50 個數據,每條記錄都有 4 項特徵:花萼長度、花萼寬度、花瓣長度、花瓣寬度,可以通過這4個特徵預測鳶尾花卉屬於(iris-setosa, iris-versicolour, iris-virginica)中的哪一品種。

  • 這裏就以鳶尾花數據集爲例one.txt:
5.1,3.5,1.4,0.2,Iris-setosa
4.9,3.0,1.4,0.2,Iris-setosa
4.7,3.2,1.3,0.2,Iris-setosa
4.6,3.1,1.5,0.2,Iris-setosa
5.0,3.6,1.4,0.2,Iris-setosa
5.4,3.9,1.7,0.4,Iris-setosa
4.6,3.4,1.4,0.3,Iris-setosa
5.0,3.4,1.5,0.2,Iris-setosa
4.4,2.9,1.4,0.2,Iris-setosa
4.9,3.1,1.5,0.1,Iris-setosa
5.4,3.7,1.5,0.2,Iris-setosa
4.8,3.4,1.6,0.2,Iris-setosa
4.8,3.0,1.4,0.1,Iris-setosa
4.3,3.0,1.1,0.1,Iris-setosa
5.8,4.0,1.2,0.2,Iris-setosa
5.7,4.4,1.5,0.4,Iris-setosa
5.4,3.9,1.3,0.4,Iris-setosa
5.1,3.5,1.4,0.3,Iris-setosa
5.7,3.8,1.7,0.3,Iris-setosa
5.1,3.8,1.5,0.3,Iris-setosa
5.4,3.4,1.7,0.2,Iris-setosa
5.1,3.7,1.5,0.4,Iris-setosa
4.6,3.6,1.0,0.2,Iris-setosa
5.1,3.3,1.7,0.5,Iris-setosa
4.8,3.4,1.9,0.2,Iris-setosa
5.0,3.0,1.6,0.2,Iris-setosa
5.0,3.4,1.6,0.4,Iris-setosa
5.2,3.5,1.5,0.2,Iris-setosa
5.2,3.4,1.4,0.2,Iris-setosa
4.7,3.2,1.6,0.2,Iris-setosa
4.8,3.1,1.6,0.2,Iris-setosa
5.4,3.4,1.5,0.4,Iris-setosa
5.2,4.1,1.5,0.1,Iris-setosa
5.5,4.2,1.4,0.2,Iris-setosa
4.9,3.1,1.5,0.1,Iris-setosa
5.0,3.2,1.2,0.2,Iris-setosa
5.5,3.5,1.3,0.2,Iris-setosa
4.9,3.1,1.5,0.1,Iris-setosa
4.4,3.0,1.3,0.2,Iris-setosa
5.1,3.4,1.5,0.2,Iris-setosa
5.0,3.5,1.3,0.3,Iris-setosa
4.5,2.3,1.3,0.3,Iris-setosa
4.4,3.2,1.3,0.2,Iris-setosa
5.0,3.5,1.6,0.6,Iris-setosa
5.1,3.8,1.9,0.4,Iris-setosa
4.8,3.0,1.4,0.3,Iris-setosa
5.1,3.8,1.6,0.2,Iris-setosa
4.6,3.2,1.4,0.2,Iris-setosa
5.3,3.7,1.5,0.2,Iris-setosa
5.0,3.3,1.4,0.2,Iris-setosa
7.0,3.2,4.7,1.4,Iris-versicolor
6.4,3.2,4.5,1.5,Iris-versicolor
6.9,3.1,4.9,1.5,Iris-versicolor
5.5,2.3,4.0,1.3,Iris-versicolor
6.5,2.8,4.6,1.5,Iris-versicolor
5.7,2.8,4.5,1.3,Iris-versicolor
6.3,3.3,4.7,1.6,Iris-versicolor
4.9,2.4,3.3,1.0,Iris-versicolor
6.6,2.9,4.6,1.3,Iris-versicolor
5.2,2.7,3.9,1.4,Iris-versicolor
5.0,2.0,3.5,1.0,Iris-versicolor
5.9,3.0,4.2,1.5,Iris-versicolor
6.0,2.2,4.0,1.0,Iris-versicolor
6.1,2.9,4.7,1.4,Iris-versicolor
5.6,2.9,3.6,1.3,Iris-versicolor
6.7,3.1,4.4,1.4,Iris-versicolor
5.6,3.0,4.5,1.5,Iris-versicolor
5.8,2.7,4.1,1.0,Iris-versicolor
6.2,2.2,4.5,1.5,Iris-versicolor
5.6,2.5,3.9,1.1,Iris-versicolor
5.9,3.2,4.8,1.8,Iris-versicolor
6.1,2.8,4.0,1.3,Iris-versicolor
6.3,2.5,4.9,1.5,Iris-versicolor
6.1,2.8,4.7,1.2,Iris-versicolor
6.4,2.9,4.3,1.3,Iris-versicolor
6.6,3.0,4.4,1.4,Iris-versicolor
6.8,2.8,4.8,1.4,Iris-versicolor
6.7,3.0,5.0,1.7,Iris-versicolor
6.0,2.9,4.5,1.5,Iris-versicolor
5.7,2.6,3.5,1.0,Iris-versicolor
5.5,2.4,3.8,1.1,Iris-versicolor
5.5,2.4,3.7,1.0,Iris-versicolor
5.8,2.7,3.9,1.2,Iris-versicolor
6.0,2.7,5.1,1.6,Iris-versicolor
5.4,3.0,4.5,1.5,Iris-versicolor
6.0,3.4,4.5,1.6,Iris-versicolor
6.7,3.1,4.7,1.5,Iris-versicolor
6.3,2.3,4.4,1.3,Iris-versicolor
5.6,3.0,4.1,1.3,Iris-versicolor
5.5,2.5,4.0,1.3,Iris-versicolor
5.5,2.6,4.4,1.2,Iris-versicolor
6.1,3.0,4.6,1.4,Iris-versicolor
5.8,2.6,4.0,1.2,Iris-versicolor
5.0,2.3,3.3,1.0,Iris-versicolor
5.6,2.7,4.2,1.3,Iris-versicolor
5.7,3.0,4.2,1.2,Iris-versicolor
5.7,2.9,4.2,1.3,Iris-versicolor
6.2,2.9,4.3,1.3,Iris-versicolor
5.1,2.5,3.0,1.1,Iris-versicolor
5.7,2.8,4.1,1.3,Iris-versicolor
6.3,3.3,6.0,2.5,Iris-virginica
5.8,2.7,5.1,1.9,Iris-virginica
7.1,3.0,5.9,2.1,Iris-virginica
6.3,2.9,5.6,1.8,Iris-virginica
6.5,3.0,5.8,2.2,Iris-virginica
7.6,3.0,6.6,2.1,Iris-virginica
4.9,2.5,4.5,1.7,Iris-virginica
7.3,2.9,6.3,1.8,Iris-virginica
6.7,2.5,5.8,1.8,Iris-virginica
7.2,3.6,6.1,2.5,Iris-virginica
6.5,3.2,5.1,2.0,Iris-virginica
6.4,2.7,5.3,1.9,Iris-virginica
6.8,3.0,5.5,2.1,Iris-virginica
5.7,2.5,5.0,2.0,Iris-virginica
5.8,2.8,5.1,2.4,Iris-virginica
6.4,3.2,5.3,2.3,Iris-virginica
6.5,3.0,5.5,1.8,Iris-virginica
7.7,3.8,6.7,2.2,Iris-virginica
7.7,2.6,6.9,2.3,Iris-virginica
6.0,2.2,5.0,1.5,Iris-virginica
6.9,3.2,5.7,2.3,Iris-virginica
5.6,2.8,4.9,2.0,Iris-virginica
7.7,2.8,6.7,2.0,Iris-virginica
6.3,2.7,4.9,1.8,Iris-virginica
6.7,3.3,5.7,2.1,Iris-virginica
7.2,3.2,6.0,1.8,Iris-virginica
6.2,2.8,4.8,1.8,Iris-virginica
6.1,3.0,4.9,1.8,Iris-virginica
6.4,2.8,5.6,2.1,Iris-virginica
7.2,3.0,5.8,1.6,Iris-virginica
7.4,2.8,6.1,1.9,Iris-virginica
7.9,3.8,6.4,2.0,Iris-virginica
6.4,2.8,5.6,2.2,Iris-virginica
6.3,2.8,5.1,1.5,Iris-virginica
6.1,2.6,5.6,1.4,Iris-virginica
7.7,3.0,6.1,2.3,Iris-virginica
6.3,3.4,5.6,2.4,Iris-virginica
6.4,3.1,5.5,1.8,Iris-virginica
6.0,3.0,4.8,1.8,Iris-virginica
6.9,3.1,5.4,2.1,Iris-virginica
6.7,3.1,5.6,2.4,Iris-virginica
6.9,3.1,5.1,2.3,Iris-virginica
5.8,2.7,5.1,1.9,Iris-virginica
6.8,3.2,5.9,2.3,Iris-virginica
6.7,3.3,5.7,2.5,Iris-virginica
6.7,3.0,5.2,2.3,Iris-virginica
6.3,2.5,5.0,1.9,Iris-virginica
6.5,3.0,5.2,2.0,Iris-virginica
6.2,3.4,5.4,2.3,Iris-virginica
5.9,3.0,5.1,1.8,Iris-virginica

2.2. Python代碼實現SVM多分類

  • 程序代碼如下:
import numpy as np

batchsz = 150

def obtain_w_via_gradient_descent(x, c, y, penalty_c, threshold = 1e-19, learn_rate = 1e-4):
    """ 利用梯度下降法求解如下的SVM問題:min 1/2 * w^T * w + C * Σ_i=1:n(max(0, 1 - y_i * (w^T * x_i + b)))
    :param x: 訓練樣本 x = [x_1, x_2, ..., x_i]
    :param c: 類別數
    :param y: 樣本標籤 y = [y_1, y_2, ..., y_c]
    :param threshold: 梯度下降停止閾值
    """
    data_num = np.shape(x)[1]
    feature_dim = np.shape(x)[0]
    w = np.ones([feature_dim, c], dtype=np.float32)
    b = np.ones([c, 1], dtype=np.float32)
    dl_dw = np.zeros([feature_dim, c], dtype=np.float)
    dl_db = np.zeros([c, 1], dtype=np.float)
    it = 1
    th = 0.1
    while it < 60000 and th > threshold:
        a = np.tile(b, [1, data_num])
        ksi = (np.transpose(w) @ x + np.tile(b, [1, data_num])) * y
        index_martix = ksi < 1

        for class_num in range(c):
            index_vector = index_martix[class_num, :]

            if True in index_vector:
                x_c = x[:, index_vector]

                data_num_c = np.shape(x_c)[1]
                e = np.ones([data_num_c, 1], dtype=np.float)
                y_c = np.reshape(y[class_num, index_vector], [data_num_c, 1])
                w_c = np.reshape(w[:, class_num], [feature_dim, 1])
                b_c = b[class_num]

                dl_dw[:, class_num] = (w_c + 2 * penalty_c * (x_c @ np.transpose(x_c) @ w_c +
                                                              x_c @ e * b_c -
                                                              x_c @ y_c))[:, 0]
                dl_db[class_num, 0] = 2 * penalty_c * (b_c * data_num_c +
                                                       np.transpose(w_c) @ x_c @ e -
                                                       np.transpose(y_c) @ e)
            else:
                w_c = np.reshape(w[:, class_num], [feature_dim, 1])
                dl_dw[:, class_num] = w_c[:, 0]
                dl_db[class_num, 0] = 0

        w_ = w - learn_rate * (dl_dw / np.linalg.norm(dl_dw, ord=2))
        b_ = b - learn_rate * dl_db

        th = np.sum(np.square(w_ - w)) + np.sum(np.square(b_ - b))
        it = it + 1

        w = w_
        b = b_

        if it % 200 == 0:
            y_predict = np.transpose(w) @ x + np.tile(b, [1, data_num])
            correct_prediction = np.equal(np.argmax(y_predict, 0), np.argmax(y, 0))
            accuracy = np.mean(correct_prediction.astype(np.float))
            print("epoch:", it, "acc:", accuracy)

def iris_type(s):
    it = {b'Iris-setosa': 0, b'Iris-versicolor': 1, b'Iris-virginica': 2}    # b'Iris-virginica': 2
    return it[s]

def normalize_data(data):
    mean = np.mean(data, axis=0)
    std = np.std(data, axis=0)
    for i in range(data.shape[0]):
        data[i, :] = (data[i, :] - mean) / std
    return  data

def convert_to_one_hot(y, C):
    return np.eye(C)[y.reshape(-1)]

def main():
    data = np.loadtxt('./one.txt', dtype=float, delimiter=',', converters={4: iris_type})  #
    x = data[:, :4]
    x = normalize_data(x)     # 預處理數據
    y = data[:, 4]
    y = y.astype(np.int)
    y_onehot = convert_to_one_hot(y, 3)
    y_onehot[y_onehot == 0] = -1

    x = np.transpose(x)                         # k*n  k: 特徵維度, n: 樣本數
    y_onehot = np.transpose(y_onehot)           # c*n  c: 類別數,   n: 樣本數
    w = np.array([[1, 1, 1], [1, 1, 1]])        # k*c  k: 特徵維度,  c: 類別數
    b = np.array([[1],[1],[1]])                 # c*1  c: 類別數
    obtain_w_via_gradient_descent(x, 3, y_onehot, 0.5)

if __name__ == '__main__':
    main()

2.3. 程序最終運行結果

  • 程序最終運行結果如下:
ssh://zhangkf@192.168.136.64:22/home/zhangkf/anaconda3/envs/py1/bin/python -u /home/zhangkf/johnCodes/TF1/svm_test/SVM_grad.py
epoch: 200 acc: 0.6666666666666666
epoch: 400 acc: 0.3333333333333333
epoch: 600 acc: 0.3333333333333333
epoch: 800 acc: 0.3333333333333333
epoch: 1000 acc: 0.3333333333333333
epoch: 1200 acc: 0.3333333333333333
epoch: 1400 acc: 0.3333333333333333
epoch: 1600 acc: 0.3333333333333333
epoch: 1800 acc: 0.3333333333333333
epoch: 2000 acc: 0.3333333333333333
epoch: 2200 acc: 0.3333333333333333
epoch: 2400 acc: 0.3333333333333333
epoch: 2600 acc: 0.34
epoch: 2800 acc: 0.34
epoch: 3000 acc: 0.36
epoch: 3200 acc: 0.36666666666666664
epoch: 3400 acc: 0.38
epoch: 3600 acc: 0.3933333333333333
epoch: 3800 acc: 0.4
epoch: 4000 acc: 0.4266666666666667
epoch: 4200 acc: 0.43333333333333335
epoch: 4400 acc: 0.47333333333333333
epoch: 4600 acc: 0.5
epoch: 4800 acc: 0.5066666666666667
epoch: 5000 acc: 0.52
epoch: 5200 acc: 0.52
epoch: 5400 acc: 0.52
epoch: 5600 acc: 0.54
epoch: 5800 acc: 0.5533333333333333
epoch: 6000 acc: 0.5733333333333334
epoch: 6200 acc: 0.58
epoch: 6400 acc: 0.58
epoch: 6600 acc: 0.58
epoch: 6800 acc: 0.5866666666666667
epoch: 7000 acc: 0.5866666666666667
epoch: 7200 acc: 0.5933333333333334
epoch: 7400 acc: 0.5933333333333334
epoch: 7600 acc: 0.6066666666666667
epoch: 7800 acc: 0.6066666666666667
epoch: 8000 acc: 0.6266666666666667
epoch: 8200 acc: 0.6333333333333333
epoch: 8400 acc: 0.64
epoch: 8600 acc: 0.64
epoch: 8800 acc: 0.6466666666666666
epoch: 9000 acc: 0.6533333333333333
epoch: 9200 acc: 0.66
epoch: 9400 acc: 0.66
epoch: 9600 acc: 0.66
epoch: 9800 acc: 0.66
epoch: 10000 acc: 0.6466666666666666
epoch: 10200 acc: 0.6533333333333333
epoch: 10400 acc: 0.6533333333333333
epoch: 10600 acc: 0.6533333333333333
epoch: 10800 acc: 0.6533333333333333
epoch: 11000 acc: 0.6533333333333333
epoch: 11200 acc: 0.6533333333333333
epoch: 11400 acc: 0.66
epoch: 11600 acc: 0.66
epoch: 11800 acc: 0.66
epoch: 12000 acc: 0.6666666666666666
epoch: 12200 acc: 0.6733333333333333
epoch: 12400 acc: 0.6866666666666666
epoch: 12600 acc: 0.6866666666666666
epoch: 12800 acc: 0.6866666666666666
epoch: 13000 acc: 0.6866666666666666
epoch: 13200 acc: 0.6866666666666666
epoch: 13400 acc: 0.6933333333333334
epoch: 13600 acc: 0.7133333333333334
epoch: 13800 acc: 0.72
epoch: 14000 acc: 0.7333333333333333
epoch: 14200 acc: 0.74
epoch: 14400 acc: 0.7466666666666667
epoch: 14600 acc: 0.7533333333333333
epoch: 14800 acc: 0.76
epoch: 15000 acc: 0.76
epoch: 15200 acc: 0.7666666666666667
epoch: 15400 acc: 0.7666666666666667
epoch: 15600 acc: 0.7666666666666667
epoch: 15800 acc: 0.7666666666666667
epoch: 16000 acc: 0.78
epoch: 16200 acc: 0.78
epoch: 16400 acc: 0.7866666666666666
epoch: 16600 acc: 0.7933333333333333
epoch: 16800 acc: 0.7933333333333333
epoch: 17000 acc: 0.7933333333333333
epoch: 17200 acc: 0.7933333333333333
epoch: 17400 acc: 0.8066666666666666
epoch: 17600 acc: 0.8066666666666666
epoch: 17800 acc: 0.82
epoch: 18000 acc: 0.8266666666666667
epoch: 18200 acc: 0.82
epoch: 18400 acc: 0.82
epoch: 18600 acc: 0.8266666666666667
epoch: 18800 acc: 0.8266666666666667
epoch: 19000 acc: 0.8266666666666667
epoch: 19200 acc: 0.8266666666666667
epoch: 19400 acc: 0.8333333333333334
epoch: 19600 acc: 0.8333333333333334
epoch: 19800 acc: 0.8333333333333334
epoch: 20000 acc: 0.8333333333333334
epoch: 20200 acc: 0.8333333333333334
epoch: 20400 acc: 0.8466666666666667
epoch: 20600 acc: 0.8533333333333334
epoch: 20800 acc: 0.86
epoch: 21000 acc: 0.8666666666666667
epoch: 21200 acc: 0.8666666666666667
epoch: 21400 acc: 0.8666666666666667
epoch: 21600 acc: 0.8666666666666667
epoch: 21800 acc: 0.8666666666666667
epoch: 22000 acc: 0.8666666666666667
epoch: 22200 acc: 0.8666666666666667
epoch: 22400 acc: 0.8666666666666667
epoch: 22600 acc: 0.8666666666666667
epoch: 22800 acc: 0.8666666666666667
epoch: 23000 acc: 0.8666666666666667
epoch: 23200 acc: 0.8666666666666667
epoch: 23400 acc: 0.8666666666666667
epoch: 23600 acc: 0.8666666666666667
epoch: 23800 acc: 0.8666666666666667
epoch: 24000 acc: 0.8666666666666667
epoch: 24200 acc: 0.8666666666666667
epoch: 24400 acc: 0.8666666666666667
epoch: 24600 acc: 0.8666666666666667
epoch: 24800 acc: 0.8666666666666667
epoch: 25000 acc: 0.8666666666666667
epoch: 25200 acc: 0.8666666666666667
epoch: 25400 acc: 0.8666666666666667
epoch: 25600 acc: 0.8666666666666667
epoch: 25800 acc: 0.8666666666666667
epoch: 26000 acc: 0.8666666666666667
epoch: 26200 acc: 0.8666666666666667
epoch: 26400 acc: 0.8666666666666667
epoch: 26600 acc: 0.86
epoch: 26800 acc: 0.86
epoch: 27000 acc: 0.86
epoch: 27200 acc: 0.86
epoch: 27400 acc: 0.86
epoch: 27600 acc: 0.86
epoch: 27800 acc: 0.86
epoch: 28000 acc: 0.86
epoch: 28200 acc: 0.8666666666666667
epoch: 28400 acc: 0.86
epoch: 28600 acc: 0.86
epoch: 28800 acc: 0.86
epoch: 29000 acc: 0.86
epoch: 29200 acc: 0.86
epoch: 29400 acc: 0.86
epoch: 29600 acc: 0.86
epoch: 29800 acc: 0.86
epoch: 30000 acc: 0.8533333333333334
epoch: 30200 acc: 0.8533333333333334
epoch: 30400 acc: 0.86
epoch: 30600 acc: 0.86
epoch: 30800 acc: 0.8666666666666667
epoch: 31000 acc: 0.8666666666666667
epoch: 31200 acc: 0.8666666666666667
epoch: 31400 acc: 0.8666666666666667
epoch: 31600 acc: 0.86
epoch: 31800 acc: 0.86
epoch: 32000 acc: 0.86
epoch: 32200 acc: 0.86
epoch: 32400 acc: 0.86
epoch: 32600 acc: 0.86
epoch: 32800 acc: 0.86
epoch: 33000 acc: 0.86
epoch: 33200 acc: 0.86
epoch: 33400 acc: 0.86
epoch: 33600 acc: 0.86
epoch: 33800 acc: 0.86
epoch: 34000 acc: 0.8666666666666667
epoch: 34200 acc: 0.8666666666666667
epoch: 34400 acc: 0.8666666666666667
epoch: 34600 acc: 0.8666666666666667
epoch: 34800 acc: 0.8666666666666667
epoch: 35000 acc: 0.8666666666666667
epoch: 35200 acc: 0.8866666666666667
epoch: 35400 acc: 0.8866666666666667
epoch: 35600 acc: 0.8933333333333333
epoch: 35800 acc: 0.9
epoch: 36000 acc: 0.9
epoch: 36200 acc: 0.9
epoch: 36400 acc: 0.9
epoch: 36600 acc: 0.9
epoch: 36800 acc: 0.9
epoch: 37000 acc: 0.9
epoch: 37200 acc: 0.9
epoch: 37400 acc: 0.9
epoch: 37600 acc: 0.9
epoch: 37800 acc: 0.9
epoch: 38000 acc: 0.9
epoch: 38200 acc: 0.9066666666666666
epoch: 38400 acc: 0.9066666666666666
epoch: 38600 acc: 0.9133333333333333
epoch: 38800 acc: 0.9133333333333333
epoch: 39000 acc: 0.9133333333333333
epoch: 39200 acc: 0.9133333333333333
epoch: 39400 acc: 0.9133333333333333
epoch: 39600 acc: 0.9133333333333333
epoch: 39800 acc: 0.9133333333333333
epoch: 40000 acc: 0.9133333333333333
epoch: 40200 acc: 0.9133333333333333
epoch: 40400 acc: 0.9133333333333333
epoch: 40600 acc: 0.9133333333333333
epoch: 40800 acc: 0.9133333333333333
epoch: 41000 acc: 0.9133333333333333
epoch: 41200 acc: 0.9133333333333333
epoch: 41400 acc: 0.9133333333333333
epoch: 41600 acc: 0.9266666666666666
epoch: 41800 acc: 0.9266666666666666
epoch: 42000 acc: 0.9266666666666666
epoch: 42200 acc: 0.9333333333333333
epoch: 42400 acc: 0.9333333333333333
epoch: 42600 acc: 0.9333333333333333
epoch: 42800 acc: 0.9333333333333333
epoch: 43000 acc: 0.9333333333333333
epoch: 43200 acc: 0.9333333333333333
epoch: 43400 acc: 0.9333333333333333
epoch: 43600 acc: 0.9333333333333333
epoch: 43800 acc: 0.9333333333333333
epoch: 44000 acc: 0.9333333333333333
epoch: 44200 acc: 0.9333333333333333
epoch: 44400 acc: 0.9333333333333333
epoch: 44600 acc: 0.9333333333333333
epoch: 44800 acc: 0.94
epoch: 45000 acc: 0.94
epoch: 45200 acc: 0.94
epoch: 45400 acc: 0.94
epoch: 45600 acc: 0.9466666666666667
epoch: 45800 acc: 0.9466666666666667
epoch: 46000 acc: 0.9466666666666667
epoch: 46200 acc: 0.9466666666666667
epoch: 46400 acc: 0.9466666666666667
epoch: 46600 acc: 0.9466666666666667
epoch: 46800 acc: 0.9466666666666667
epoch: 47000 acc: 0.9466666666666667
epoch: 47200 acc: 0.9466666666666667
epoch: 47400 acc: 0.9466666666666667
epoch: 47600 acc: 0.9466666666666667
epoch: 47800 acc: 0.9466666666666667
epoch: 48000 acc: 0.9466666666666667
epoch: 48200 acc: 0.9466666666666667
epoch: 48400 acc: 0.9466666666666667
epoch: 48600 acc: 0.9466666666666667
epoch: 48800 acc: 0.9466666666666667
epoch: 49000 acc: 0.9466666666666667
epoch: 49200 acc: 0.9466666666666667
epoch: 49400 acc: 0.9466666666666667
epoch: 49600 acc: 0.9466666666666667
epoch: 49800 acc: 0.9466666666666667
epoch: 50000 acc: 0.9466666666666667
epoch: 50200 acc: 0.9466666666666667
epoch: 50400 acc: 0.9466666666666667
epoch: 50600 acc: 0.9466666666666667
epoch: 50800 acc: 0.9466666666666667
epoch: 51000 acc: 0.9466666666666667
epoch: 51200 acc: 0.9466666666666667
epoch: 51400 acc: 0.9466666666666667
epoch: 51600 acc: 0.9533333333333334
epoch: 51800 acc: 0.9533333333333334
epoch: 52000 acc: 0.96
epoch: 52200 acc: 0.96
epoch: 52400 acc: 0.96
epoch: 52600 acc: 0.96
epoch: 52800 acc: 0.96
epoch: 53000 acc: 0.96
epoch: 53200 acc: 0.96
epoch: 53400 acc: 0.96
epoch: 53600 acc: 0.96
epoch: 53800 acc: 0.96
epoch: 54000 acc: 0.96
epoch: 54200 acc: 0.96
epoch: 54400 acc: 0.96
epoch: 54600 acc: 0.96
epoch: 54800 acc: 0.96
epoch: 55000 acc: 0.96
epoch: 55200 acc: 0.96
epoch: 55400 acc: 0.96
epoch: 55600 acc: 0.96
epoch: 55800 acc: 0.96
epoch: 56000 acc: 0.96
epoch: 56200 acc: 0.96
epoch: 56400 acc: 0.96
epoch: 56600 acc: 0.96
epoch: 56800 acc: 0.96
epoch: 57000 acc: 0.96
epoch: 57200 acc: 0.96
epoch: 57400 acc: 0.96
epoch: 57600 acc: 0.96
epoch: 57800 acc: 0.96
epoch: 58000 acc: 0.96
epoch: 58200 acc: 0.96
epoch: 58400 acc: 0.96
epoch: 58600 acc: 0.96
epoch: 58800 acc: 0.96
epoch: 59000 acc: 0.96
epoch: 59200 acc: 0.96
epoch: 59400 acc: 0.96
epoch: 59600 acc: 0.96
epoch: 59800 acc: 0.96
epoch: 60000 acc: 0.96

Process finished with exit code 0

三. 隨機梯度下降實現

import numpy as np
from time import *

batchsz = 500
np.random.seed(0)

# 0. 定義函數實現mini_batch
def mini_batches(X, Y, mini_batch_size=batchsz, seed=0):
    np.random.seed(seed)
    m = X.shape[0]                                # m是樣本數

    mini_batches = []                             # 用來存放一個一個的mini_batch

    num_complete_minibatches = int(m // mini_batch_size)  # 樣本總數除以每個batch的樣本數量
    for i in range(num_complete_minibatches):
        mini_batch_X = X[i * mini_batch_size:(i + 1) * mini_batch_size, :]
        mini_batch_Y = Y[i * mini_batch_size:(i + 1) * mini_batch_size, :]
        mini_batch = (mini_batch_X, mini_batch_Y)
        mini_batches.append(mini_batch)

    if m % mini_batch_size != 0:
        # 如果樣本數不能被整除,取餘下的部分
        mini_batch_X = X[num_complete_minibatches * mini_batch_size:, :]
        mini_batch_Y = Y[num_complete_minibatches * mini_batch_size, :]
        mini_batch = (mini_batch_X, mini_batch_Y)
        mini_batches.append(mini_batch)
    return mini_batches


# mini_batches = mini_batches(X_train, y_train, mini_batch_size=64, seed=0)
#
# mini_batches[780][0].shape
# (64, 32, 32, 3)


# 1. 隨機梯度下降法實現優化SVM
def obtain_w_via_gradient_descent(x, c, y, penalty_c, x_test, y_test_onehot, threshold = 1e-19, learn_rate = 1e-4):
    """ 利用梯度下降法求解如下的SVM問題:min 1/2 * w^T * w + C * Σ_i=1:n(max(0, 1 - y_i * (w^T * x_i + b)))
    :param x: 訓練樣本 x = [x_1, x_2, ..., x_i]
    :param c: 類別數
    :param y: 樣本標籤 y = [y_1, y_2, ..., y_c]
    :param threshold: 梯度下降停止閾值
    """
    data_num = np.shape(x)[1]
    feature_dim = np.shape(x)[0]
    w = np.ones([feature_dim, c], dtype=np.float32)
    b = np.ones([c, 1], dtype=np.float32)
    dl_dw = np.zeros([feature_dim, c], dtype=np.float)
    dl_db = np.zeros([c, 1], dtype=np.float)
    epoch = 1
    th = 0.1

    iterations = mini_batches(x.T, y.T, batchsz, seed=0)  # mini_batchs
    print(iterations[0][0].shape)

    begin_time = time()
    while epoch < 100000 and th > threshold:

        for x_y in iterations:
            x = x_y[0].T
            y = x_y[1].T
            a = np.tile(b, [1, batchsz])
            ksi = (np.transpose(w) @ x + np.tile(b, [1, batchsz])) * y
            index_martix = ksi < 1

            for class_num in range(c):
                index_vector = index_martix[class_num, :]

                if True in index_vector:
                    x_c = x[:, index_vector]

                    data_num_c = np.shape(x_c)[1]
                    e = np.ones([data_num_c, 1], dtype=np.float)
                    y_c = np.reshape(y[class_num, index_vector], [data_num_c, 1])
                    w_c = np.reshape(w[:, class_num], [feature_dim, 1])
                    b_c = b[class_num]

                    dl_dw[:, class_num] = (w_c + 2 * penalty_c * (x_c @ np.transpose(x_c) @ w_c +
                                                                  x_c @ e * b_c -
                                                                  x_c @ y_c))[:, 0]
                    dl_db[class_num, 0] = 2 * penalty_c * (b_c * data_num_c +
                                                           np.transpose(w_c) @ x_c @ e -
                                                           np.transpose(y_c) @ e)
                else:
                    w_c = np.reshape(w[:, class_num], [feature_dim, 1])
                    dl_dw[:, class_num] = w_c[:, 0]
                    dl_db[class_num, 0] = 0

            w_ = w - learn_rate * (dl_dw / np.linalg.norm(dl_dw, ord=2))
            b_ = b - learn_rate * dl_db

            th = np.sum(np.square(w_ - w)) + np.sum(np.square(b_ - b))
            epoch = epoch + 1

            w = w_
            b = b_

            #############################################################################
            if epoch % 100 == 0:                   # 訓練過程中準確率打印
                y_predict = np.transpose(w) @ x + np.tile(b, [1, batchsz])
                correct_prediction = np.equal(np.argmax(y_predict, 0), np.argmax(y, 0))
                accuracy = np.mean(correct_prediction.astype(np.float))
                print("epoch:", epoch, "acc:", accuracy)

    end_time = time()
    run_time = end_time - begin_time
    print('Run  time:', run_time)             # 該循環程序運行時間
    ########################################## 測試集結果 ############################
    data_num = np.shape(x_test)[1]

    y_predict = np.transpose(w) @ x_test + np.tile(b, [1, data_num])
    correct_prediction = np.equal(np.argmax(y_predict, 0), np.argmax(y_test_onehot, 0))
    accuracy = np.mean(correct_prediction.astype(np.float))
    print("Test_acc:", accuracy)


# 2. 歸一化數據
def normalize_data(data):
    mean = np.mean(data, axis=0)
    std = np.std(data, axis=0)
    for i in range(data.shape[0]):
        data[i, :] = (data[i, :] - mean) / std
    return  data

# 3. 轉化爲one_hot編碼
def convert_to_one_hot(y, C):
    return np.eye(C)[y.reshape(-1)]

# 4. 隨機打散訓練數據和相應的標籤
def random_scattered(data):
    index = np.arange(data.shape[0])
    np.random.shuffle(index)
    data = data[index,:]
    return data


def main():
    # 1. 數據集加載
    data = np.loadtxt('one2.txt', dtype=float, delimiter=',')
    # 2. 隨機打散訓練數據和相應的標籤
    data = random_scattered(data)
    # 3. 拆分訓練數據和測試數據;
    train_num = int(0.75 * data.shape[0])
    data_train_label = data[:train_num, :]      # 訓練集75%
    data_test_lable = data[train_num + 1:, :]   # 測試集25%

    ################################### 訓練集 ###################################
    x = data_train_label[:, :4]
    # 3. 歸一化
    x = normalize_data(x)
    y = data_train_label[:, 4]
    y = y.astype(np.int)-1
    # 4. 轉換one_hot編碼
    y_onehot = convert_to_one_hot(y, 2)
    y_onehot[y_onehot == 0] = -1

    x = np.transpose(x)                         # k*n  k: 特徵維度, n: 樣本數
    y_onehot = np.transpose(y_onehot)           # c*n  c: 類別數,   n: 樣本數
    w = np.array([[1, 1, 1], [1, 1, 1]])        # k*c  k: 特徵維度,  c: 類別數
    b = np.array([[1],[1],[1]])                 # c*1  c: 類別數

    ################################### 測試集 ###################################
    x_test = data_test_lable[:, :4]
    # 3. 歸一化
    x_test = normalize_data(x_test)
    y_test = data_test_lable[:, 4]
    y_test = y_test.astype(np.int)-1
    # 4. 轉換one_hot編碼
    y_test_onehot = convert_to_one_hot(y_test, 2)
    y_test_onehot[y_test_onehot == 0] = -1

    x_test = np.transpose(x_test)                         # k*n  k: 特徵維度, n: 樣本數
    y_test_onehot = np.transpose(y_test_onehot)           # c*n  c: 類別數,   n: 樣本數


    obtain_w_via_gradient_descent(x, 2, y_onehot, 0.5, x_test, y_test_onehot)

if __name__ == '__main__':
    main()

  • one2.txt部分
6.64E+01,3.53E+02,5.35E+03,9.50E+02,1
6.18E+01,2.34E+02,4.77E+03,9.50E+02,1
6.71E+01,2.80E+02,5.44E+03,9.32E+02,1
6.30E+01,2.06E+02,4.89E+03,9.27E+02,1
6.67E+01,1.75E+02,5.37E+03,9.15E+02,1
7.54E+01,8.78E+01,6.60E+03,9.14E+02,1
6.83E+01,2.16E+02,5.58E+03,9.14E+02,1
6.58E+01,1.79E+02,5.24E+03,9.12E+02,1
6.93E+01,2.53E+02,5.71E+03,9.10E+02,1
6.73E+01,2.69E+02,5.44E+03,9.08E+02,1
6.20E+01,7.08E+01,4.75E+03,9.06E+02,1
6.70E+01,1.03E+02,5.39E+03,9.04E+02,1
6.43E+01,1.08E+02,5.04E+03,9.03E+02,1
6.73E+01,2.90E+02,5.43E+03,9.01E+02,1
6.64E+01,2.08E+02,5.31E+03,8.98E+02,1
6.52E+01,9.82E+01,5.15E+03,8.91E+02,1
6.45E+01,2.13E+02,5.05E+03,8.91E+02,1
6.21E+01,1.50E+02,4.75E+03,8.90E+02,1
6.72E+01,2.05E+02,5.40E+03,8.84E+02,1
6.33E+01,2.74E+02,4.89E+03,8.83E+02,1
6.46E+01,1.53E+02,5.05E+03,8.83E+02,1
6.48E+01,1.33E+02,4.94E+03,7.45E+02,2
7.55E+01,1.34E+02,6.44E+03,7.45E+02,2
7.04E+01,3.49E+02,5.70E+03,7.45E+02,2
7.03E+01,1.74E+02,5.69E+03,7.45E+02,2
6.70E+01,1.48E+02,5.23E+03,7.45E+02,2
7.25E+01,1.25E+02,6.00E+03,7.45E+02,2
6.72E+01,2.20E+02,5.26E+03,7.45E+02,2
7.56E+01,9.81E+01,6.46E+03,7.45E+02,2
7.54E+01,2.08E+02,6.43E+03,7.45E+02,2
6.83E+01,1.29E+02,5.41E+03,7.45E+02,2
6.35E+01,1.56E+02,4.78E+03,7.45E+02,2
6.45E+01,2.58E+02,4.90E+03,7.45E+02,2
6.44E+01,2.17E+02,4.89E+03,7.45E+02,2
6.42E+01,1.02E+02,4.87E+03,7.45E+02,2
6.09E+01,1.53E+02,4.45E+03,7.45E+02,2
6.66E+01,2.02E+02,5.18E+03,7.45E+02,2
6.53E+01,1.17E+02,5.01E+03,7.45E+02,2

補充:python中讀取文件txt文件的幾種方式

  • new.txt
1000025,5,1,1,1,2,1,3,1,1,2
1002945,5,4,4,5,7,10,3,2,1,2
1151734,10,8,7,4,3,10,7,9,1,4
1156017,3,1,1,1,2,1,2,1,1,2
1158247,1,1,1,1,1,1,1,1,1,2
1238021,1,1,1,1,2,1,2,1,1,2
1238464,1,1,1,1,1,?,2,1,1,2
1238633,10,10,10,6,8,4,8,5,1,4
1295186,10,10,10,1,6,1,2,8,1,4
527337,4,1,1,1,2,1,1,1,1,2
558538,4,1,3,3,2,1,1,1,1,2
1266124,5,1,2,1,2,1,1,1,1,2
1296025,4,1,2,1,2,1,1,1,1,2
1296263,4,1,1,1,2,1,1,1,1,2
1296593,5,2,1,1,2,1,1,1,1,2
1299161,4,8,7,10,4,10,7,5,1,4
1301945,5,1,1,1,1,1,1,1,1,2
1302428,5,3,2,4,2,1,1,1,1,2
1318169,9,10,10,10,10,5,10,10,10,4
1113061,5,1,1,1,2,1,3,1,1,2
1116192,5,1,2,1,2,1,3,1,1,2
1135090,4,1,1,1,2,1,2,1,1,2
1145420,6,1,1,1,2,1,2,1,1,2
1158157,5,1,1,1,2,2,2,1,1,2
1171578,3,1,1,1,2,1,1,1,1,2
1174841,5,3,1,1,2,1,1,1,1,2
1184586,4,1,1,1,2,1,2,1,1,2
1186936,2,1,3,2,2,1,2,1,1,2
1197527,5,1,1,1,2,1,2,1,1,2
1222464,6,10,10,10,4,10,7,10,1,4
1240603,2,1,1,1,1,1,1,1,1,2
import numpy as np
import pandas as pd
def main():
    data = pd.read_csv('new.txt')  # 數據集用逗號分隔,直接用txt,也可以讀取CSV格式的。
    data = data.values             # DataFrame類型轉換成Numpy中array類型,並把表頭去掉;
    data

if __name__ == '__main__':
    main()

創作不易,歡迎點贊轉發!

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章