分類問題中的決策面畫法 (直觀理解plt.contour的用法)

摘要

通過分類問題中決策面的繪製過程直觀理解matplotlib中contour的用法,主要包括對 np.meshgrid 和plt.contour的直觀理解。

前言

分類問題中,我們習慣用2維的dmeo做例子,驗證算法的有效性。直觀的評價方法是在散點圖上畫一個決策面(decision bondary)來可視化的顯示分類結果。

我們借鑑scikit learn中的一個例子,代碼如下:

# 這裏我稍微調整了下plt.contour中的參數,使得結果更好看一點
def plot_decision_boundary(model, x, y):
    x_min, x_max = x[:, 0].min() - 0.5, x[:, 0].max() + 0.5
    y_min, y_max = x[:, 1].min() - 0.5, x[:, 1].max() + 0.5
    h = 0.01
    # 繪製網格
    xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
    # 生成與網格上所有點對應的分類結果
    z = model(np.c_[xx.ravel(), yy.ravel()])
    z = z.reshape(xx.shape)
    # 繪製contour
    plt.contour(xx, yy, z, levels=[0.5], colors=['blue'])
    plt.scatter(x[:, 0], x[:, 1], c=y)

效果(這裏我搭建了一個具有三個隱藏層的簡單神經網絡,用來對two moons dataset 進行分類):

Decision Boundary
我們總結下上述代碼,發現畫決策面可以分爲三個步驟:

  1. 在想要繪製決策面的數據範圍生成網格(xx, yy)
  2. 將網格上所有的點輸入分類器並得到輸出 (z)
  3. 通過網格(xx, yy)和輸出結果(z)繪製一條level=0.5的contour

接下來就裏面的meshgrid和contour繪製過程做一個簡單的小例子,來直觀理解上述過程。

step 1: 生成Meshgrid

我們用一個簡單的小例子直觀理解meshgrid生成了什麼。代碼如下:

In [1]: import numpy as np

In [2]: x = np.arange(1, 5, 1)

In [3]: y = np.arange(1, 3, 1)

In [4]: x
Out[4]: array([1, 2, 3, 4])

In [5]: y
Out[5]: array([1, 2])

In [6]: xx, yy = np.meshgrid(x, y)

In [7]: xx
Out[7]:
array([[1, 2, 3, 4],
       [1, 2, 3, 4]])

In [8]: yy
Out[8]:
array([[1, 1, 1, 1],
       [2, 2, 2, 2]])

In [9]: xx[0][1]
Out[9]: 2

In [10]: yy[0][1]
Out[10]: 1

分析:我們設置x=[1,2,3,4],y=[1,2]x = [1, 2, 3, 4], y = [1, 2]。發現meshgrid 生成的是兩個2×42\times 4的矩陣,並且把它們的對應位置拼接起來, 如xx([0][0],yy[0][0])=(1,1)xx([0][0], yy[0][0]) = (1, 1)(xx[0][1],yy[0][1])=(2,1)(xx[0][1], yy[0][1]) = (2, 1)就是下面的這樣一個網格:

mesh grid

Step 2: 將meshgrid 上的點輸入分類器

爲了方便,我們首先利用上述xx,yyxx, yy生成要輸入分類器的數據點。即上圖網格中的所有點對。

In [19]: xx.ravel()
Out[19]: array([1, 2, 3, 4, 1, 2, 3, 4])

In [20]: yy.ravel()
Out[20]: array([1, 1, 1, 1, 2, 2, 2, 2])

In [21]: np.c_[xx.ravel(), yy.ravel()] # 按列組合兩個矩陣
Out[21]:
array([[1, 1],
       [2, 1],
       [3, 1],
       [4, 1],
       [1, 2],
       [2, 2],
       [3, 2],
       [4, 2]])

In [22]: np.c_[xx.ravel(), yy.ravel()][0]
Out[22]: array([1, 1])

可以看到,這裏首先把xxxxyyyy拉直,然後再按列組合生成了要輸如的數據點,再輸入分類器即可。

Step 3:

contour and contourf draw contour lines and filled contours, respectively.

其參數如下:contour([X, Y,] Z, [levels], **kwargs)

可以藉助中學物理的等高線來直觀理解contour的畫法,[X, Y]是座標點,Z是每個點對應的高度,levels = 0.5的意思是我們要在z=0.5的地方畫一條等高線。

s

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章