pytorch 分割一个tensor并求平均

pytorch的torch.add()torch.split()函数

import torch
# outputs是一个[batch, seq, 40]维的tensor,把outputs分割成两个[batch, seq, 20]的tensor,并每个元素求平均值
add = torch.add(*torch.split(outputs, 20, dim=2)) / 2
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章