pytorch將標籤轉爲onehot

由於想多分類中使用Diceloss,所以需要將[0,1,2,..N]類型的標籤轉化爲onehot類型。

input數據類型: torch.LongTensor()

        數據形狀:[bs, 1, *]         可爲2D或3D數據      

def make_one_hot(input, num_classes):
    """Convert class index tensor to one hot encoding tensor.
    Args:
         input: A tensor of shape [bs, 1, *]
         num_classes: An int of number of class
    Returns:
        A tensor of shape [bs, num_classes, *]
    """
    shape = np.array(input.shape)
    shape[1] = num_classes
    shape = tuple(shape)
    result = torch.zeros(shape)
    result = result.scatter_(1, input.cpu(), 1)

    return result

溫馨提示:其他需要的知識:

1、FloatTensor轉化爲LongTensor:

# 此時的輸入label爲FloatTensor,可在cuda,也可是cpu
label_long = label.long()

2、 Tensor增加一個維度

label_onehot = label_onehot.unsqueeze(1)   #在第一維增加一個維度

3、多分類交叉熵是不需要將標籤轉爲onehot的

詳情請查看  https://blog.csdn.net/longshaonihaoa/article/details/105253553

4、最近版pytorch有直接的轉化爲onehot的代碼,瞭解之後更新。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章