pytorch裏面的torch.gather操作
This article was original written by XRBLS, welcome re-post, first come with https://jinfagang.github.io . but please keep this copyright info, thanks, any question could be asked via wechat:
jintianiloveu
torch.gather
只是一個引子,別看它簡單,但能引出很多問題。我們先來看看,它是如何工作的。假如我們有一個矩陣:
[[34, 4, 6],
[45, 6, 7]]
我們想要對它每一個位置的點進行重新排列應該怎麼做呢?比如我要得到這麼一個矩陣:
[[4, 6, 34],
[6, 7, 45]]
可以看到,我把每一行的(此時的axis=1)位置進行了變換。具體來說,用torch.gather
可以做這個事情:
r = torch.gather(a, 1, torch.tensor([[1, 2, 0], [1, 2, 0]]))
說白了,就是用一個矩陣來對它進行重排。那麼到底在什麼場合我們會用到這個函數呢?
其實一個很明顯的作用就是在分類問題中,通過gather方法可以從一個矩陣裏面挑選出最大值來完成分類任務。
之前有遇到一個onnx2trt的問題,但是本質上並不是由於它造成的,跟gather沒有太大的關係。