pytorch裏面的torch.gather操作

pytorch裏面的torch.gather操作

This article was original written by XRBLS, welcome re-post, first come with https://jinfagang.github.io . but please keep this copyright info, thanks, any question could be asked via wechat: jintianiloveu

torch.gather 只是一個引子,別看它簡單,但能引出很多問題。我們先來看看,它是如何工作的。假如我們有一個矩陣:

[[34, 4, 6],
[45, 6, 7]]

我們想要對它每一個位置的點進行重新排列應該怎麼做呢?比如我要得到這麼一個矩陣:

[[4, 6, 34],
[6, 7, 45]]

可以看到,我把每一行的(此時的axis=1)位置進行了變換。具體來說,用torch.gather可以做這個事情:

r = torch.gather(a, 1, torch.tensor([[1, 2, 0], [1, 2, 0]]))

說白了,就是用一個矩陣來對它進行重排。那麼到底在什麼場合我們會用到這個函數呢?

其實一個很明顯的作用就是在分類問題中,通過gather方法可以從一個矩陣裏面挑選出最大值來完成分類任務。

之前有遇到一個onnx2trt的問題,但是本質上並不是由於它造成的,跟gather沒有太大的關係。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章