從零開始深度學習Pytorch筆記——張量的拼接與切分

本文研究張量的拼接與切分。張量的拼接import torch(1) 使用torch.cat()拼接將張量按維度dim進行拼接,不會擴張張量的維度torch.cat(tensors, dim=0, out=None)其中:tensors:張量序列dim:要拼接的維度t = torch.ones((3,2))t0 = torch.cat([t,t],dim=0)#在第0個維度上拼接t1 = torch.cat([t,t],dim=1)#在第1個維度上拼接print(t0,’\n\n’,t1)t2 = torch.cat([t,t,t],dim=0)t2(2) 使用torch.stack()拼接在新創建的維度dim上進行拼接,會擴張張量的維度torch.stack(tensors, dim=0, out=None)參數:tensors:張量序列dim:要拼接的維度t = torch.ones((3,2))t1 = torch.stack([t,t],dim=2)#在新創建的維度上進行拼接print(t1,t1.shape) #拼接完會從2維變成3維我們可以看到維度從拼接前的(3,2)變成了(3,2,2),即在最後的維度上進行了拼接!t = torch.ones((3,2))t1 = torch.stack([t,t],dim=0)#在新創建的維度上進行拼接#由於指定是第0維,會把原來的3,2往後移動一格,然後在新的第0維創建新維度進行拼接print(t1,t1.shape)t = torch.ones((3,2))t1 = torch.stack([t,t,t],dim=0)#在新創建的維度上進行拼接#由於是第0維,會把原來的3,2往後移動一格,然後在新的第0維創建新維度進行拼接print(t1,t1.shape)張量的切分(1) 使用torch.chunk()切分可以將張量按維度dim進行平均切分return 張量列表如果不能整除,最後一份張量小於其他張量torch.chunk(input, chunks, dim=0)參數:input:要切分的張量chunks:要切分的份數dim:要切分的維度a = torch.ones((5,2))t = torch.chunk(a,dim=0,chunks=2)#在5這個維度切分,切分成2個張量for idx, t_chunk in enumerate(t): print(idx,t_chunk,t_chunk.shape)可以看出後一個張量小於前一個張量的,前者第0個維度上是3,後者是2。(2) 使用torch.split()切分將張量按維度dim進行切分return:張量列表torch.split(tensor, split_size_or_sections, dim=0)參數:tensor:要切分的張量split_size_or_sections:爲int時,表示每一份的長度;爲list時,按list元素切分dim:要切分的維度a = torch.ones((5,2))t = torch.split(a,2,dim=0)#指定了每個張量的長度爲2for idx, t_split in enumerate(t): print(idx,t_split,t_split.shape)#切出3個張量a = torch.ones((5,2))t = torch.split(a,[2,1,2],dim=0)#指定了每個張量的長度爲列表中的大小【2,1,2】for idx, t_split in enumerate(t): print(idx,t_split,t_split.shape)#切出3個張量a = torch.ones((5,2))t = torch.split(a,[2,1,1],dim=0)#list中求和不爲長度則拋出異常for idx, t_split in enumerate(t): print(idx,t_split,t_split.shape)#切出3個張量RuntimeError:split_with_sizes expects split_sizes to sum exactly to 5 (input tensor’s size at dimension 0), but got split_sizes=[2, 1, 1]

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章