torch.stack()解析

torch.stack()是将原来的几个tensor按照一定方式进行堆叠,然后在按照堆叠后的维度进行切分
在这里插入图片描述

有a,b,c三个tensor.
在这里插入图片描述

dim=0

在这里插入图片描述

dim=1

在这里插入图片描述

dim=2

在这里插入图片描述
参考:torch.stack(), torch.cat()用法详解

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章