關於pytorch用tensor索引另一個tensor問題

1. 問題

在項目中關於如下代碼出現問題:

def fun1():
    start_transitions = torch.nn.Parameter(torch.empty(5))
    torch.nn.init.uniform_(start_transitions, -0.1, 0.1)
    tags_b = tags.byte()
    rs = start_transitions[tags_b[0]]
    print(rs)

出現的錯誤:

The shape of the mask [20] at index 0 does not match the shape of the indexed tensor [5] at index 0

問題:報錯原因在哪裏?

2. 關於tensor的索引問題

研究支持tensor的索引有哪些類型,運行如下:

def fun2():
    start_transitions = torch.nn.Parameter(torch.empty(5))
    torch.nn.init.uniform_(start_transitions, -0.1, 0.1)
    tags_i = tags.int()
    print(start_transitions[tags_i[0]])

出現如下的報錯:

tensors used as indices must be long, byte or bool tensors

結論:tensors的下標必須爲long或byte類型。
陷阱:long與type的作用又不一樣。

3. 關於byte類型作爲下標

 def fun_tyte_indx():
    start_transitions = torch.nn.Parameter(torch.empty(5))
    torch.nn.init.uniform_(start_transitions, -0.1, 0.1)
    tags_b = tags.byte()
    print(start_transitions[tags_b[0][:5]])
    print(start_transitions[tags_b[0]])

對於輸出,第一個沒有錯誤,已輸出來,可是最後一行出現報錯:

tensor([-0.0560, -0.0440,  0.0341, -0.0022, -0.0191], grad_fn=<IndexBackward>)
The shape of the mask [20] at index 0 does not match the shape of the indexed tensor [5] at index 0

結論:Byte類型的下標操作像是一個mask,將原有tensor進行篩選一遍,取出tensor2 對應位置不爲0的元素;

4. 關於long類型作爲下標

def fun_long_index():
    a = torch.arange(16, 30)
    print('a=', a)
    index_list = [[4, 1, 2], [2, 1, 1]]
    c = torch.LongTensor(index_list)
    print('c:', c)
    print('a[c]:', a[c])
    a = a.view(7, 2)
    print('a_7*2:', a)
    print('a[c]:', a[c])
    print(a.shape, c.shape, a[c].shape)

輸出內容:

a= tensor([16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29])
c: tensor([[4, 1, 2],
        [2, 1, 1]])
a[c]: tensor([[20, 17, 18],
        [18, 17, 17]])
a_7*2: tensor([[16, 17],
        [18, 19],
        [20, 21],
        [22, 23],
        [24, 25],
        [26, 27],
        [28, 29]])
a[c]: tensor([[[24, 25],
         [18, 19],
         [20, 21]],

        [[20, 21],
         [18, 19],
         [18, 19]]])
torch.Size([7, 2]) torch.Size([2, 3]) torch.Size([2, 3, 2])

從例子來看,c相當於中將所有的元素替換成a中指定位置的元素;對於多維選擇dim=0;
結論:相當於在 tensor2 中將所有的元素替換成tensor1中指定位置的元素;

5. 總結

同樣的代碼,同樣的數值類型,就是由於保存的位數不一樣,會產生了不一樣的結果。這樣對於一直以來使用高級語言數值類型會自動轉思維定勢要帶來一些未知bug產生。需三思。

[happyprince] https://blog.csdn.net/ld326/article/details/105114212

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章