數據預處理 One-hot 編碼的兩種實現方式

1. 什麼是 One-hot 編碼

最直觀的理解就是,比如說現在有三個類別A、B、C,它們對應的標籤值分別爲 [1, 2, 3],如果對這三個類別使用One-hot編碼,得到的結果則是,[[1, 0, 0], [0, 1, 0], [0, 0, 1]],相當於:

  • 1 被編碼爲 1 0 0
  • 2 被編碼爲 0 1 0
  • 3 被編碼爲 0 0 1

2. 爲什麼要對數據進行 One-hot 編碼

分割任務中,網絡模型最後的輸出shape爲 [N, C, H, W] (以pytoch爲例, 其中N爲batch_size, C爲預測的類別數),而我們給的的gt(ground truth)的shape一般爲[H, W, 3](彩色圖或rgb圖)或[H, W](灰度圖)。
假設我們現在的分割任務裏面有5個目標需要分割,給定的gt是彩色的。則網絡模型最後的輸出shape爲 [N, 5, H, W],這和gt的shape不匹配,在訓練的時候它們兩者之間不能進行損失值計算。因此,就需要使用One-hot編碼對gt進行編碼,將其編碼爲[H, W, 5],最後再對維度進行transpose即可。

編碼前和編碼後的變化類似圖中所示(上圖對應編碼前,下圖對應編碼後)。
{% asset_img 1.png %}
(圖片來源:https://www.eefocus.com/communication/413211/r0)
{% asset_img 2.png %}
(圖片來源:https://www.eefocus.com/communication/413211/r0)

3.代碼實現

3.1 方法一

mask_to_onehot用來將標籤進行one-hot,onehot_to_mask用來恢復one-hot,在可視化的時候使用。

def mask_to_onehot(mask, palette):
    """
    Converts a segmentation mask (H, W, C) to (H, W, K) where the last dim is a one
    hot encoding vector, C is usually 1 or 3, and K is the number of class.
    """
    semantic_map = []
    for colour in palette:
        equality = np.equal(mask, colour)
        class_map = np.all(equality, axis=-1)
        semantic_map.append(class_map)
    semantic_map = np.stack(semantic_map, axis=-1).astype(np.float32)
    return semantic_map

def onehot_to_mask(mask, palette):
    """
    Converts a mask (H, W, K) to (H, W, C)
    """
    x = np.argmax(mask, axis=-1)
    colour_codes = np.array(palette)
    x = np.uint8(colour_codes[x.astype(np.uint8)])
    return x

方法一在使用的時候需要先定義好顏色表palette(根據自己的數據集來定義就行了)。下面演示兩個例子。

假設gt是灰度圖,需要分割兩個目標(正常器官和腫瘤)(加上背景就是3分類任務),正常器官的灰度值爲128,腫瘤的灰度值爲255, 背景的灰度值爲0。

palette = [[0], [128], [255]]  # 裏面值的順序不是固定的,可以按自己的要求來
# 注意:灰度圖的話要確保 gt的 shape = [H, W, 1],該函數實在最後的通道維上進行映射
# 如果加載後的gt的 shape = [H, W],則需要進行通道的擴維
gt_onehot = mask_to_onehot(gt, palette)  # one-hot 後 gt的shape=[H, W, 3]

假設gt彩色圖,需要分割5個目標(加上背景就是6分類任務),顏色值如下。 和灰度圖的處理方法類似。

palette = [[0, 0, 0], [192, 224, 224], [128, 128, 64], [0, 192, 128], [128, 128, 192], [128, 128, 0]]
gt_onehot = mask_to_onehot(gt, palette)  # one-hot 後 gt的shape=[H, W, 6]

3.1 方法二

爲了以示區別,名字不要起的一樣。

def mask2onehot(mask, num_classes):
    """
    Converts a segmentation mask (H,W) to (K,H,W) where the last dim is a one
    hot encoding vector

    """
    _mask = [mask == i for i in range(num_classes)]
    return np.array(_mask).astype(np.uint8)

def onehot2mask(mask):
    """
    Converts a mask (K, H, W) to (H,W)
    """
    _mask = np.argmax(mask, axis=0).astype(np.uint8)
    return _mask

用法:如果gt是灰度圖,如上面的例子,用起來就比較簡單。

# 需要先指定每個類別的顏色值對應的標籤
# 注意: 第一類從0開始,而不是從1開始
label2trainid = {0: 0, 128: 1, 255: 2}
gt_copy = gt.copy()
# 這一步相當於把
for k, v in label2trainid.items():
    gt_copy[gt == k] = v
gt_with_trainid = gt_copy.astype(np.uint8)

gt_onehot = mask2onehot(gt_with_trainid, 3) # one-hot 後 gt的shape=[3, H, W]

如果gt是彩色圖,要先把rgb顏色值映射爲標籤,再進行one-hot編碼,相對來說就比較繁瑣了。直接用方法一就行了。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章