5種常用的交叉驗證技術,保證評估模型的穩定性

 

Kaggle的數據科學黑客大會最有趣和最具挑戰性的一件事是:在公共和私有的排行榜中,努力保持同樣的排名。當我的結果在一個私有的排行榜進行驗證時,我就失去了共有的排名。

 

你有沒有想過是什麼原因導致了這些排名的高差異?換句話說,爲什麼一個模型在私有排行榜上評估時會失去穩定性?

在本文中,我們將討論可能的原因。我們還將學習交叉驗證和執行它的各種方法。

模型的穩定性?

總是需要驗證你的機器學習模型的穩定性。換句話說,你不能把這個模型與你的訓練數據相匹配,並預測它的未來日期,然後希望它每次都能準確地給出結果。我之所以強調這一點是因爲每次模型預測未來的日期,它都是基於看不見的數據,這些數據可能與訓練數據不同。如果訓練模型不能從你的訓練數據中捕捉趨勢,那麼它可能會在測試集上過度擬合或不擬合。換句話說,可能會有高的方差或偏差。

讓我們通過一個例子來進一步瞭解模型的穩定性。

在這個例子中,我們試圖找出一個人購買汽車與否的關係,這取決於他的收入。爲此,我們採取了以下步驟:

我們用一個線性方程建立了買車與否和個人收入之間的關係。假設你有2010年到2019年的數據,並試圖預測2020年。您已經根據可用的列車數據訓練了您的模型。

 

在第一個圖中,我們可以說,該模型捕捉到了訓練數據的每一個趨勢,包括噪音。該模型的精度非常高,誤差極小。這被稱爲過擬合,因爲模型已經考慮了數據點的每一個偏差(包括噪聲),而且模型太敏感,只能捕獲當前數據集中出現的每一個模式。正是由於這個原因,可能會產生高偏差。

在第二個圖中,我們只是找到了兩個變量之間的最優關係,即低訓練誤差和更一般化的關係。

在第三個圖中,我們發現該模型在列車數據上表現不佳,精度較低,誤差%較大。因此,這種模式不會有很好的表現。這是不合適的典型例子。在這種情況下,我們的模型無法捕捉訓練數據的潛在趨勢。

在Kaggle的許多機器學習比賽中常見的做法是在不同的模型上進行迭代,以尋找一個性能更好的模型。然而,很難區分分數的提高是因爲我們更好地捕捉了變量之間的關係,還是我們只是過度擬合了訓練數據。爲了更多地瞭解這一點,機器學習論壇上的許多人使用了各種驗證技術。這有助於實現更一般化的關係,並維護模型的穩定性。

交叉驗證是什麼?

交叉驗證是一種在機器學習中用於評估機器學習模型性能的統計驗證技術。它使用數據集的子集,對其進行訓練,然後使用未用於訓練的數據集的互補子集來評估模型的性能。它可以保證模型正確地從數據中捕獲模式,而不考慮來自數據的干擾。

交叉驗證使用的標準步驟:

  • 它將數據集分爲訓練和測試兩部分。
  • 它在訓練數據集上訓練模型。
  • 它在測試集中評估相同的模型。
  • 交叉驗證技術可以有不同的風格。

交叉驗證中使用的各種方法

TrainTestSplit

 

這是一種基本的交叉驗證技術。在這種技術中,我們使用數據的一個子集作爲模型訓練的訓練數據,並在另一組被稱爲測試集的數據上評估模型的性能,如圖所示。誤差估計然後告訴我們的模型在看不見的數據或測試集上的表現。這是一種簡單的交叉驗證技術,也被稱爲驗證方法。這種技術存在差異大的問題。這是因爲不確定哪些數據點會出現在測試集或訓練集&這會導致巨大的方差,而且不同的集合可能會產生完全不同的結果。

n次交叉驗證/ k次交叉驗證

 

總有需要大量的數據來訓練模型,將測試數據集的一部分可以離開不理解的模型數據的模式可能會導致錯誤,也可能導致增加欠擬合模型的測試數據。爲了克服這個問題,有一種交叉驗證技術,它爲模型的訓練提供了充足的數據,也爲驗證留下了充足的數據。K摺疊交叉驗證正是這樣做的。

n次交叉驗證涉及的步驟:

  1. 基於N- fold分割你的整個數據集。
  2. 對於數據集中的每n次摺疊,在數據集的N-1次摺疊上構建模型。然後,對模型進行檢驗,檢驗n次摺疊的有效性
  3. 在預測中記錄每次迭代的錯誤。重複這個步驟,直到每一個n -fold都作爲測試集
  4. 你的N個記錄錯誤的平均值被稱爲交叉驗證錯誤,它將作爲模型的性能度量。

例如:

假設數據有100個數據點。基於這100個數據點,你想預測下一個數據點。然後可以使用100條記錄進行交叉驗證。假設摺疊次數(N) = 10。

  1. 100個數據點被分成10個桶,每個桶有10條記錄。
  2. 在這裏,根據數據和N值創建了10個摺疊。現在,在10次摺疊中,9次摺疊會被用作你的訓練數據並在10次摺疊
  3. 測試你的模型。迭代這個過程,直到每次摺疊都成爲您的測試。計算你在所有摺疊上選擇的度規的平均值。這個度量將有助於更好地一般化模型,並增加模型的穩定性。

交叉驗證(LOOCV)

在這種方法中,我們將現有數據集中的一個數據點放在一邊,並在其餘數據上訓練模型。這個過程迭代,直到每個數據點被用作測試集。這也有它的優點和缺點。讓我們來看看它們:

 

我們利用所有的數據點,因此偏差會很低

我們根據數據集中可用的數據點的數量重複n次交叉驗證過程,這會導致更高的執行時間和更高的計算量。

由於我們只對一個數據點進行測試,如果該測試數據點是一個離羣點,可能會導致較高的誤差%,因此我們不能基於這種技術對模型進行推廣。

分層n倍交叉驗證

在某些情況下,數據可能有很大的不平衡。對於這類數據,我們使用了不同的交叉驗證技術,即分層n次交叉驗證,即每一次交叉驗證都是平衡的,並且包含每個類的樣本數量大致相同。分層是一種重新安排數據的技術,以確保每一褶都能很好地代表數據中出現的所有類。

例如,在關於個人收入預測的dataset中,可能有大量的人低於或高於50K。最好的安排總是使數據在每個摺疊中包含每個類的幾乎一半實例。

Cross Validation for time series

將時間序列數據隨機分割爲摺疊數是行不通的,因爲這種類型的數據是依賴於時間的。對這類數據的交叉驗證應該跨時間進行。對於一個時間序列預測問題,我們採用以下方法進行交叉驗證。

時間序列交叉驗證的摺疊以向前鏈接的方式創建。

 

例如,假設我們有一個時間序列,顯示了一家公司2014年至2019年6年間的年汽車需求。摺疊的創建方式如下:

Train 1— [2014]

Test 1— [2015]

Train2–[2014,2015]

Test2 — [2016]….so on

 

我們逐步地選擇一個新的列車和測試集。我們選擇一個列車集,它具有最小的觀測量來擬合模型。逐步地,我們在每個摺疊中改變我們的列車和測試集。

總結

在本文中,我們討論了過擬合、欠擬合、模型穩定性和各種交叉驗證技術,以避免過擬合和欠擬合。我們還研究了不同的交叉驗證技術,如驗證方法、LOOCV、n次交叉驗證、n次分層驗證等等。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章