由於想多分類中使用Diceloss,所以需要將[0,1,2,..N]類型的標籤轉化爲onehot類型。
input數據類型: torch.LongTensor()
數據形狀:[bs, 1, *] 可爲2D或3D數據
def make_one_hot(input, num_classes):
"""Convert class index tensor to one hot encoding tensor.
Args:
input: A tensor of shape [bs, 1, *]
num_classes: An int of number of class
Returns:
A tensor of shape [bs, num_classes, *]
"""
shape = np.array(input.shape)
shape[1] = num_classes
shape = tuple(shape)
result = torch.zeros(shape)
result = result.scatter_(1, input.cpu(), 1)
return result
溫馨提示:其他需要的知識:
1、FloatTensor轉化爲LongTensor:
# 此時的輸入label爲FloatTensor,可在cuda,也可是cpu
label_long = label.long()
2、 Tensor增加一個維度
label_onehot = label_onehot.unsqueeze(1) #在第一維增加一個維度
3、多分類交叉熵是不需要將標籤轉爲onehot的
詳情請查看 https://blog.csdn.net/longshaonihaoa/article/details/105253553
4、最近版pytorch有直接的轉化爲onehot的代碼,瞭解之後更新。