【调参7】不平衡分类问题中分类权重计算与设置



代码环境:

  • python-3.7.6
  • tensorflow-2.1.0

前言

最近几个月一直在做时间序列分类相关的工作,在实际应用工作中,调整模型参数对模型的性能表现的影响比较大。通过设置分类权重平衡原来数据集中样本分布不均衡的情况,同时配合Adam优化算法和BatchNormalization,准确率有比较大的提升。

本文重点介绍【时间序列分类任务】中【不平衡分类】问题,如何配置 tensorflow.keras.model.fit() API 中的 class_weight 参数,以平衡分类。其实在时间序列分类任务中,也可以通过插值的方式扩充数据样本,使各类样本数量相同,来平衡各类别。

不平衡分类是个比较大的问题,以后再写吧。本文先介绍如何在时间序列分类任务中的应用。


1. 类别权重如何计算

一开始看到分类平衡中的分类权重参数(tensorflow.keras.model.fit中的参数为 class_weight)的计算时比较懵,不知道如何计算的,大部分文章基本都没有提及计算过程,不适合我这样的小白。于是,动手算了算,发现其实很简单。

class_weight 参数需要传入一个各类别权重字典,通过分类数量计算得到,其的意义是模型对各类别的关注程度。 注意与 sample_weight 区分。

如果要平衡分类,一种简单的方式是通过设置该参数,让模型重点关注样本数目较少的类别,即乘以不同的权重,达到类似各类样本数目都相同的目的。


计算的方式有很多种,举个例子说明如何计算分类权重。

假设有A,B,C,D四类样本,各类样本数量如下:

分类:A	  	 B	      C	  	   D       total
数量:10  	 20       30       40	   100
比例:1/10   2/10     3/10    4/10     10/10

假设平衡后模型对各类别的关注度相同,即每类样本应该受到相同的关注,取个名字,称它为平均关注度吧(随便起的,方便表示)即 :
=1=14()平均关注度=\frac{1}{分类数} =\frac{1}{4}(本例中)

那么,每类样本的比例乘以各自的权重,应该等于平均关注度。easy!则有:
=某类样本的权重 * 样本所占数据集比例 = 平均关注度

如此可以求出某类样本平衡后应该乘以的权重:
==1某类样本的权重 = \frac{平均关注度}{样本所占数据集比例}=\frac{1}{分类数}*\frac{样本总数}{某类样本数量}

简单的推导。。。想明白转化关系其实很简单。清楚了计算公式,下面讲一下如何实现。


2. tensorflow.keras.model.fit API 配置

fit(
    x=None, y=None, batch_size=None, epochs=1, verbose=1, callbacks=None,
    ...
    class_weight=None,
    sample_weight=None, 
    ...
)

关键参数说明:

  • class_weight:可选参数。样本标签(整型,已经编码)到权重值(float类型)的映射,用来加权损失函数(仅在训练期间加权)。这可以有效地告诉模型需要更关注哪些样本。
  • sample_weight:可选参数。训练样本权重数组,用来加权损失函数(仅在训练期间加权)。

本文仅介绍如何配置 class_weight 参数。


3. 实现方法

3.1 数据集介绍

本文使用的数据集是在个人创建的适合业务场景的数据集。包含十个类别,以7:3划分训练集和验证集。训练集中各样本数量如下:

类别 	数量
1		582
2		580
3		1688
4		1152
5		580
6		580
7		2470
8		571
9		534
10		578
total	9315

在本文的数据集中,原数据存储在csv文件中,各类标签已经编码为单个整型数字。

首先,以 pandas.DataFrame 格式读取,滑动窗口切分完成后,提取各样本已经编码的标签。其实直接从原始数据集(还未使用滑动窗口)提取标签计算是一样的,相差不会太多,无非是滑动窗口可能丢弃了不完整的数据,实际影响不大,这也是设置分类权重而不设置样本权重的好处,一劳永逸~。如果在滑动窗口处理数据的过程中,标签已经转换成 one-hot 编码,使用 np.argmax(trainy,axis=1) 转换成单个整型数字编码即可。

下面看一下代码实现。


3.2 代码实现

有了以上准备,现在应该有:单个整型数字组成的标签序列。比如:

y_train
-> array([6, 6, 6, ..., 3, 3, 3], dtype=int64)
len(y_train)
-> 9315

现在进一步处理:

  1. 提取样本标签:
import numpy as np
classes = np.unique(y_train)
array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int64)
  1. 初始化标签编码API
from sklearn.preprocessing import LabelEncoder
le = LabelEncoder()
y_index = le.fit_transform(y_train.ravel()) # np.ravel()返回包含输入元素的一维数组
y_index
array([6, 6, 6, ..., 3, 3, 3], dtype=int64)
  1. 计算总样本数
len(y_train)
9315
  1. 计算总类别数
len(classes)
10
  1. 计算各类别样本数量
np.bincount(y_index) # np.bincount()计算非负整数数组中每个值的出现次数。
array([ 582,  580, 1688, 1152,  580,  580, 2470,  571,  534,  578],
      dtype=int64)
  1. 标准化标签编码
le.transform(classes)
array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int64)

这里说明一下,跟np.unique() 输出相同,np.unique() 是取出数组中的唯一元素,并返回包含这些元素的数组。le.transform() 标准化编码,这一步是保证输入数据是真标签的情况下,转换为数字编码。

  1. 计算权重
recip_freq = len(y_train) / (len(classes) * np.bincount(y_index).astype(np.float64))
recip_freq
array([1.60051546, 1.60603448, 0.55183649, 0.80859375, 1.60603448,
       1.60603448, 0.37712551, 1.63134851, 1.74438202, 1.6115917 ])
  1. 映射
class_weight = recip_freq[le.transform(classes)]
class_weight
array([1.60051546, 1.60603448, 0.55183649, 0.80859375, 1.60603448,
       1.60603448, 0.37712551, 1.63134851, 1.74438202, 1.6115917 ])

3.3 完整代码

如果训练标签已经转化为one-hot编码:

y_train = np.argmax(trainy, axis=1)

如果使用原始数据集处理:

y_train = train_loaded['state_encode'] # pd.series

完整代码:

from sklearn.preprocessing import LabelEncoder
le = LabelEncoder()

y_index = le.fit_transform(y_train.ravel())
recip_freq = len(y_train) / (len(classes) * np.bincount(y_index).astype(np.float64))
class_weight = recip_freq[le.transform(classes)]
print("Class weights : ", class_weight)

# 将训练标签变为one-hot编码;这里使用keras API 实现,使用pd.get_dummies也可以。
y_train = tf.keras.utils.to_categorical(y_train, len(np.unique(y_train)))

y_train:

array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)

相关API 官方文档

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章