torch.stack 和 torch.cat 錯誤:argument 'tensors' (position 1) must be tuple of Tensors, not Tensor

本篇博文介紹pytorch中一些函數的輸入問題,主要是tensor 和 tensors的區別。
在pytorch中我們也有對一個數據的疊加:
pytorch.stack ,這個函數可以在數據疊加的同時,擴展數據維度。比如說我們把三個數疊加到一起,可以組成一個二維的矩陣,得到的二維矩陣可以是[1, 2],也可以是[2, 1]。
pytorch.cat,這個函數是直接把兩個數據連接起來,維度是不變的。還是上面那個例子,把三個數疊加到一起,維度還是1,爲[2].

不過,我相信有些小夥伴對於這個兩個函數的輸入可能有一些不確定的地方。
首先,我們要知道,tensor和tensors是不同的東西。前者這裏不解釋了,後者是tensors,可以認爲是tensor的列表,或者是tensor的元組。也就是說torch.stack 和 torch.cat 的輸入是一個列表或者元組, 這個列表和元組的內容是由tensor組成的。

先看官網:torch.stack。
在這裏插入圖片描述

然後有三個例子,分別說明:

  1. 輸入爲列表,正確,狹義上有也分兩種形式: 直接輸入一個列表,或者將tensor組成列表
  2. 輸入爲元組,正確,也分兩種,直接一個元組,或者有tensor組成一個元組。
  3. 輸入爲tensor,出錯
    注意到元組和列表的區別。前者可以使用()來顯性表示,後者使用[]來顯性表示。
    在這裏插入圖片描述
    在這裏插入圖片描述
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章