import torch

# Size of the first dimension; the demo matrix is (nums, 5).
nums = 3

# First, generate a random 3x5 matrix.
a = torch.randn(nums, 5)

# Reduce along dim=0 (the row dimension): for each of the 5 columns,
# `overlap_for_each_prior` holds the column maximum and
# `object_for_each_prior` holds the row index that maximum came from.
overlap_for_each_prior, object_for_each_prior = a.max(dim=0)

# Reduce along dim=1 (the column dimension): for each of the `nums` rows,
# `prior_for_each_object` holds the column index of the row maximum.
_, prior_for_each_object = a.max(dim=1)  # (N_o)

print(a)
print(overlap_for_each_prior)
print('++++++++++++++++++')
print(object_for_each_prior)
print(prior_for_each_object)
##################################################
##################################################
# This is the key step: advanced (fancy) indexing READS the entries of
# `object_for_each_prior` at the positions listed in `prior_for_each_object`.
print(object_for_each_prior[prior_for_each_object])
##################################################
##################################################
print(torch.LongTensor(range(nums)))
print('++++++++++++++++++')
# Advanced-indexing ASSIGNMENT (a write, not a read): stores the values
# nums..2*nums-1 into `object_for_each_prior` at the positions given by
# `prior_for_each_object`.
# NOTE(review): if `prior_for_each_object` contains duplicate indices, which
# of the candidate values lands there is not guaranteed by PyTorch.
object_for_each_prior[prior_for_each_object] = torch.LongTensor(range(nums, 2*nums))
print('++++++++++++++++++')
print(object_for_each_prior)
print('++++++++++++++++++')
代碼解析:
# 首先,我們隨機生成一個3*5的矩陣
a = torch.randn(nums, 5)
# 在列的維度(dim=0),取每列的最大值
overlap_for_each_prior, object_for_each_prior = a.max(dim=0)
_, prior_for_each_object = a.max(dim=1) # (N_o)
這個時候,假設
a的值爲:
'''
tensor([[-0.1705, 1.2972, 1.8852, -1.0077, -0.6337],
[ 1.5984, -0.6461, 0.3798, -0.4751, 0.9754],
[ 0.7052, 0.4189, 0.1964, 1.0021, 1.6406]])
'''
那麼
overlap_for_each_prior:
'''
tensor([1.5984, 1.2972, 1.8852, 1.0021, 1.6406])
'''
object_for_each_prior:
'''
tensor([1, 0, 0, 2, 2])
'''
prior_for_each_object:
'''
tensor([2, 0, 4])
'''
然後,神一樣的操作來了:
object_for_each_prior[prior_for_each_object] = torch.LongTensor(range(nums, 2*nums))
- 首先:
'''
torch.LongTensor(range(nums, 2*nums))
'''
生成了一個Tensor數組:
'''
tensor([3, 4, 5])
'''
- 然後:
object_for_each_prior:
'''
tensor([1, 0, 0, 2, 2])
'''
以tensor([2, 0, 4])爲索引取值(即依次取出索引2, 0, 4處的元素;這是一次讀取,並不修改原張量),得到
'''
tensor([0, 1, 2])
'''
然後把object_for_each_prior中索引2, 0, 4處的值依次換成3, 4, 5(即索引2處寫入3,索引0處寫入4,索引4處寫入5).
所以,最後的結果就是
tensor([4, 0, 3, 2, 5])