PyTorch one_hot

import torch

one_hot = torch.nn.functional.one_hot(torch.arange(3), num_classes=5)
print(one_hot)
# tensor([[1, 0, 0, 0, 0],
#         [0, 1, 0, 0, 0],
#         [0, 0, 1, 0, 0]])


one_hot = torch.nn.functional.one_hot(torch.tensor([1, 3, 4]), num_classes=5)
print(one_hot)
# tensor([[0, 1, 0, 0, 0],
#         [0, 0, 0, 1, 0],
#         [0, 0, 0, 0, 1]])
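One way to see what one_hot computes in the examples above: each output row is the row of a num_classes x num_classes identity matrix picked out by the class index. The sketch below checks this equivalence, assuming PyTorch is installed; the helper name `one_hot_via_eye` is hypothetical and only for illustration, not how PyTorch implements one_hot internally.

```python
import torch
import torch.nn.functional as F


def one_hot_via_eye(indices: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Hypothetical helper: select rows of an identity matrix by class index."""
    # torch.eye builds a num_classes x num_classes identity matrix;
    # dtype=torch.long matches the int64 output of F.one_hot.
    eye = torch.eye(num_classes, dtype=torch.long)
    # Advanced indexing picks one row per class value, which appends a
    # trailing dimension of size num_classes to the input's shape.
    return eye[indices]


labels = torch.tensor([1, 3, 4])
expected = F.one_hot(labels, num_classes=5)
sketch = one_hot_via_eye(labels, num_classes=5)
print(torch.equal(expected, sketch))  # True
```

Indexing the identity matrix is convenient for reasoning about the result, but it allocates the full num_classes x num_classes matrix, so for large class counts calling F.one_hot directly is the simpler choice.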

Parameters
tensor (LongTensor) – class values of any shape.

num_classes (int) – Total number of classes. If set to -1, the number of classes will be inferred as one greater than the largest class value in the input tensor.
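The num_classes=-1 inference rule can be checked on a small batch. A sketch assuming PyTorch is installed; the label values are arbitrary. Because the inferred width depends on the largest label present, two batches can produce tensors of different widths, so passing num_classes explicitly is usually safer when labels are encoded batch by batch.

```python
import torch
import torch.nn.functional as F

labels = torch.tensor([0, 2, 7])

# num_classes defaults to -1: the width is inferred as labels.max() + 1 = 8.
inferred = F.one_hot(labels)
print(inferred.shape)  # torch.Size([3, 8])

# An explicit num_classes fixes the width, even if the batch does not
# contain the largest class.
fixed = F.one_hot(labels, num_classes=10)
print(fixed.shape)  # torch.Size([3, 10])
```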

Returns
A LongTensor with one more dimension than the input. The new last dimension has size num_classes and holds 1 at the index given by the corresponding input value, and 0 everywhere else.
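To illustrate the shape and dtype of the result for a multi-dimensional input, here is a short sketch assuming PyTorch is installed (the 2x2 label tensor is an arbitrary example). The cast at the end reflects that some losses, such as torch.nn.BCEWithLogitsLoss, expect floating-point targets rather than the int64 tensor one_hot returns.

```python
import torch
import torch.nn.functional as F

# A 2x2 batch of class indices (the input may have any shape).
labels = torch.tensor([[0, 1],
                       [2, 3]])

encoded = F.one_hot(labels, num_classes=4)
print(encoded.shape)  # torch.Size([2, 2, 4]): one extra trailing dimension
print(encoded.dtype)  # torch.int64

# The last dimension holds the one-hot vector for each input element.
print(encoded[1, 0])  # tensor([0, 0, 1, 0])

# Cast to float when a loss function expects floating-point targets.
encoded_float = encoded.float()
print(encoded_float.dtype)  # torch.float32
```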
