【調參7】不平衡分類問題中分類權重計算與設置



代碼環境:

  • python-3.7.6
  • tensorflow-2.1.0

前言

最近幾個月一直在做時間序列分類相關的工作,在實際應用工作中,調整模型參數對模型的性能表現的影響比較大。通過設置分類權重平衡原來數據集中樣本分佈不均衡的情況,同時配合Adam優化算法和BatchNormalization,準確率有比較大的提升。

本文重點介紹【時間序列分類任務】中【不平衡分類】問題,如何配置 tensorflow.keras.model.fit() API 中的 class_weight 參數,以平衡分類。其實在時間序列分類任務中,也可以通過插值的方式擴充數據樣本,使各類樣本數量相同,來平衡各類別。

不平衡分類是個比較大的問題,以後再寫吧。本文先介紹如何在時間序列分類任務中的應用。


1. 類別權重如何計算

一開始看到分類平衡中的分類權重參數(tensorflow.keras.model.fit中的參數爲 class_weight)的計算時比較懵,不知道如何計算的,大部分文章基本都沒有提及計算過程,不適合我這樣的小白。於是,動手算了算,發現其實很簡單。

class_weight 參數需要傳入一個各類別權重字典,通過分類數量計算得到,其的意義是模型對各類別的關注程度。 注意與 sample_weight 區分。

如果要平衡分類,一種簡單的方式是通過設置該參數,讓模型重點關注樣本數目較少的類別,即乘以不同的權重,達到類似各類樣本數目都相同的目的。


計算的方式有很多種,舉個例子說明如何計算分類權重。

假設有A,B,C,D四類樣本,各類樣本數量如下:

分類:A	  	 B	      C	  	   D       total
數量:10  	 20       30       40	   100
比例:1/10   2/10     3/10    4/10     10/10

假設平衡後模型對各類別的關注度相同,即每類樣本應該受到相同的關注,取個名字,稱它爲平均關注度吧(隨便起的,方便表示)即 :
=1=14()平均關注度=\frac{1}{分類數} =\frac{1}{4}(本例中)

那麼,每類樣本的比例乘以各自的權重,應該等於平均關注度。easy!則有:
=某類樣本的權重 * 樣本所佔數據集比例 = 平均關注度

如此可以求出某類樣本平衡後應該乘以的權重:
==1某類樣本的權重 = \frac{平均關注度}{樣本所佔數據集比例}=\frac{1}{分類數}*\frac{樣本總數}{某類樣本數量}

簡單的推導。。。想明白轉化關係其實很簡單。清楚了計算公式,下面講一下如何實現。


2. tensorflow.keras.model.fit API 配置

fit(
    x=None, y=None, batch_size=None, epochs=1, verbose=1, callbacks=None,
    ...
    class_weight=None,
    sample_weight=None, 
    ...
)

關鍵參數說明:

  • class_weight:可選參數。樣本標籤(整型,已經編碼)到權重值(float類型)的映射,用來加權損失函數(僅在訓練期間加權)。這可以有效地告訴模型需要更關注哪些樣本。
  • sample_weight:可選參數。訓練樣本權重數組,用來加權損失函數(僅在訓練期間加權)。

本文僅介紹如何配置 class_weight 參數。


3. 實現方法

3.1 數據集介紹

本文使用的數據集是在個人創建的適合業務場景的數據集。包含十個類別,以7:3劃分訓練集和驗證集。訓練集中各樣本數量如下:

類別 	數量
1		582
2		580
3		1688
4		1152
5		580
6		580
7		2470
8		571
9		534
10		578
total	9315

在本文的數據集中,原數據存儲在csv文件中,各類標籤已經編碼爲單個整型數字。

首先,以 pandas.DataFrame 格式讀取,滑動窗口切分完成後,提取各樣本已經編碼的標籤。其實直接從原始數據集(還未使用滑動窗口)提取標籤計算是一樣的,相差不會太多,無非是滑動窗口可能丟棄了不完整的數據,實際影響不大,這也是設置分類權重而不設置樣本權重的好處,一勞永逸~。如果在滑動窗口處理數據的過程中,標籤已經轉換成 one-hot 編碼,使用 np.argmax(trainy,axis=1) 轉換成單個整型數字編碼即可。

下面看一下代碼實現。


3.2 代碼實現

有了以上準備,現在應該有:單個整型數字組成的標籤序列。比如:

y_train
-> array([6, 6, 6, ..., 3, 3, 3], dtype=int64)
len(y_train)
-> 9315

現在進一步處理:

  1. 提取樣本標籤:
import numpy as np
classes = np.unique(y_train)
array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int64)
  1. 初始化標籤編碼API
from sklearn.preprocessing import LabelEncoder
le = LabelEncoder()
y_index = le.fit_transform(y_train.ravel()) # np.ravel()返回包含輸入元素的一維數組
y_index
array([6, 6, 6, ..., 3, 3, 3], dtype=int64)
  1. 計算總樣本數
len(y_train)
9315
  1. 計算總類別數
len(classes)
10
  1. 計算各類別樣本數量
np.bincount(y_index) # np.bincount()計算非負整數數組中每個值的出現次數。
array([ 582,  580, 1688, 1152,  580,  580, 2470,  571,  534,  578],
      dtype=int64)
  1. 標準化標籤編碼
le.transform(classes)
array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int64)

這裏說明一下,跟np.unique() 輸出相同,np.unique() 是取出數組中的唯一元素,並返回包含這些元素的數組。le.transform() 標準化編碼,這一步是保證輸入數據是真標籤的情況下,轉換爲數字編碼。

  1. 計算權重
recip_freq = len(y_train) / (len(classes) * np.bincount(y_index).astype(np.float64))
recip_freq
array([1.60051546, 1.60603448, 0.55183649, 0.80859375, 1.60603448,
       1.60603448, 0.37712551, 1.63134851, 1.74438202, 1.6115917 ])
  1. 映射
class_weight = recip_freq[le.transform(classes)]
class_weight
array([1.60051546, 1.60603448, 0.55183649, 0.80859375, 1.60603448,
       1.60603448, 0.37712551, 1.63134851, 1.74438202, 1.6115917 ])

3.3 完整代碼

如果訓練標籤已經轉化爲one-hot編碼:

y_train = np.argmax(trainy, axis=1)

如果使用原始數據集處理:

y_train = train_loaded['state_encode'] # pd.series

完整代碼:

from sklearn.preprocessing import LabelEncoder
le = LabelEncoder()

y_index = le.fit_transform(y_train.ravel())
recip_freq = len(y_train) / (len(classes) * np.bincount(y_index).astype(np.float64))
class_weight = recip_freq[le.transform(classes)]
print("Class weights : ", class_weight)

# 將訓練標籤變爲one-hot編碼;這裏使用keras API 實現,使用pd.get_dummies也可以。
y_train = tf.keras.utils.to_categorical(y_train, len(np.unique(y_train)))

y_train:

array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)

相關API 官方文檔

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章