[pytorch] 二分類交叉熵逆樣本頻率權重

通常,由於類別不均衡,需要使用weighted cross entropy loss平衡。

def inverse_freq(label):
	"""
	輸入label [N,1,H,W],1是channel數目
	"""
    den = label.sum() # 0
    _,_,h,w= label.shape
    num = h*w
    alpha = den/num # 0
    return torch.tensor([alpha, 1-alpha]).cuda()

# train
...
loss1 = F.cross_entropy(out1, label.squeeze(1).long(), weight=inverse_freq(label))

代碼比較簡單,寫在博客上保存。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章