pytorch中謎一樣的索引賦值法

test

import torch
nums = 3
# 首先,我們隨機生成一個3*5的矩陣
a = torch.randn(nums, 5)
# 在列的維度(dim=0),取每列的最大值
overlap_for_each_prior, object_for_each_prior = a.max(dim=0)
_, prior_for_each_object = a.max(dim=1)  # (N_o)
print(a)
print(overlap_for_each_prior)
print('++++++++++++++++++')
print(object_for_each_prior)
print(prior_for_each_object)
##################################################
##################################################
                 '''這裏是關鍵'''  
print(object_for_each_prior[prior_for_each_object])
                 '''這裏是關鍵'''  
##################################################
##################################################

print(torch.LongTensor(range(nums)))
print('++++++++++++++++++')
# 這一步驟的操作是以prior_for_each_object爲索引,從
# object_for_each_prior取對應的值

object_for_each_prior[prior_for_each_object] = torch.LongTensor(range(nums, 2*nums))

print('++++++++++++++++++')
print(object_for_each_prior)
print('++++++++++++++++++')

代碼解析:

# 首先,我們隨機生成一個3*5的矩陣
a = torch.randn(nums, 5)
# 在列的維度(dim=0),取每列的最大值
overlap_for_each_prior, object_for_each_prior = a.max(dim=0)
_, prior_for_each_object = a.max(dim=1)  # (N_o)

這個時候,假設

a的值爲:
'''
tensor([[-0.1705,  1.2972,  1.8852, -1.0077, -0.6337],
        [ 1.5984, -0.6461,  0.3798, -0.4751,  0.9754],
        [ 0.7052,  0.4189,  0.1964,  1.0021,  1.6406]])
'''

那麼

overlap_for_each_prior:
'''
tensor([1.5984, 1.2972, 1.8852, 1.0021, 1.6406])
'''

object_for_each_prior:
'''
tensor([1, 0, 0, 2, 2])
'''

prior_for_each_object 
'''
tensor([2, 0, 4])
'''

然後,神一樣的操作來了:

object_for_each_prior[prior_for_each_object] = torch.LongTensor(range(nums, 2*nums))
  1. 首先:
'''
torch.LongTensor(range(nums, 2*nums))
'''
生成了一個Tensor數組:
'''
tensor([3, 4, 5])
'''
  1. 然後:
object_for_each_prior:
'''
tensor([1, 0, 0, 2, 2])
'''
在tensor([2, 0, 4])的作用下變成了(取2, 0, 4對應的索引值)
'''
tensor([0, 1, 2])
'''
然後把object_for_each_prior中2,0,4對應的值換成:3,4,5.
所以,最後的結果就是
tensor([4, 0, 3, 2, 5])
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章