轉載自這篇文章
本文結構:
- 學習曲線是什麼?
- 怎麼解讀?
- 怎麼畫?
學習曲線是什麼?
學習曲線就是通過畫出不同訓練集大小時訓練集和交叉驗證的準確率,可以看到模型在新數據上的表現,進而來判斷模型是否方差偏高或偏差過高,以及增大訓練集是否可以減小過擬合。
怎麼解讀?
當訓練集和測試集的誤差收斂但卻很高時,爲高偏差。
左上角的偏差很高,訓練集和驗證集的準確率都很低,很可能是欠擬合。
我們可以增加模型參數,比如,構建更多的特徵,減小正則項。
此時通過增加數據量是不起作用的。
當訓練集和測試集的誤差之間有大的差距時,爲高方差。
當訓練集的準確率比其他獨立數據集上的測試結果的準確率要高時,一般都是過擬合。
右上角方差很高,訓練集和驗證集的準確率相差太多,應該是過擬合。
我們可以增大訓練集,降低模型複雜度,增大正則項,或者通過特徵選擇減少特徵數。
理想情況是是找到偏差和方差都很小的情況,即收斂且誤差較小。
怎麼畫?
在畫學習曲線時,橫軸爲訓練樣本的數量,縱軸爲準確率。
例如同樣的問題,左圖爲我們用 naive Bayes 分類器時,效果不太好,分數大約收斂在 0.85,此時增加數據對效果沒有幫助。
右圖爲 SVM(RBF kernel),訓練集的準確率很高,驗證集的也隨着數據量增加而增加,不過因爲訓練集的還是高於驗證集的,有點過擬合,所以還是需要增加數據量,這時增加數據會對效果有幫助。
上圖的代碼如下:
模型這裏用 GaussianNB 和 SVC 做比較,
模型選擇方法中需要用到 learning_curve 和交叉驗證方法 ShuffleSplit。
import numpy as np
import matplotlib.pyplot as plt
from sklearn.naive_bayes import GaussianNB
from sklearn.svm import SVC
from sklearn.datasets import load_digits
from sklearn.model_selection import learning_curve
from sklearn.model_selection import ShuffleSplit
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
首先定義畫出學習曲線的方法,
核心就是調用了 sklearn.model_selection 的 learning_curve,
學習曲線返回的是 train_sizes, train_scores, test_scores,
畫訓練集的曲線時,橫軸爲 train_sizes, 縱軸爲 train_scores_mean,
畫測試集的曲線時,橫軸爲 train_sizes, 縱軸爲 test_scores_mean:
def plot_learning_curve(estimator, title, X, y, ylim=None, cv=None,
n_jobs=1, train_sizes=np.linspace(.1, 1.0, 5)):
~~~
train_sizes, train_scores, test_scores = learning_curve(
estimator, X, y, cv=cv, n_jobs=n_jobs, train_sizes=train_sizes)
train_scores_mean = np.mean(train_scores, axis=1)
test_scores_mean = np.mean(test_scores, axis=1)
~~~
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
在調用 plot_learning_curve 時,首先定義交叉驗證 cv 和學習模型 estimator。
這裏交叉驗證用的是 ShuffleSplit, 它首先將樣例打散,並隨機取 20% 的數據作爲測試集,這樣取出 100 次,最後返回的是 train_index, test_index,就知道哪些數據是 train,哪些數據是 test。
estimator 用的是 GaussianNB,對應左圖:
cv = ShuffleSplit(n_splits=100, test_size=0.2, random_state=0)
estimator = GaussianNB()
plot_learning_curve(estimator, title, X, y, ylim=(0.7, 1.01), cv=cv, n_jobs=4)
- 1
- 2
- 3
- 1
- 2
- 3
再看 estimator 是 SVC 的時候,對應右圖:
cv = ShuffleSplit(n_splits=10, test_size=0.2, random_state=0)
estimator = SVC(gamma=0.001)
plot_learning_curve(estimator, title, X, y, (0.7, 1.01), cv=cv, n_jobs=4)
- 1
- 2
- 3
- 1
- 2
- 3
資料:
http://scikit-learn.org/stable/modules/learning_curve.html#learning-curve
http://scikit-learn.org/stable/auto_examples/model_selection/plot_learning_curve.html#sphx-glr-auto-examples-model-selection-plot-learning-curve-py
推薦閱讀
歷史技術博文鏈接彙總
也許可以找到你想要的