torch.stack()的官方解释,详解以及例子

可以直接看最下面的例子,再回头看前面的解释

pytorch中,常见的拼接函数主要是两个,分别是:

  1. stack()
  2. cat()

实际使用中,这两个函数互相辅助:关于cat()参考torch.cat(),但是本文主要说stack()

函数的意义:使用stack是为了保留两个信息:[1. 序列(先后)] 和 [2. 张量矩阵] 信息,而存在的扩张拼接函数。 常出现在自然语言处理(NLP)和图像卷积神经网络(CV)中。

1. stack()

官方解释:沿着一个新维度对输入张量序列进行连接。 序列中所有的张量都应该为相同形状。

浅显说法:把多个2维的张量凑成一个3维的张量;多个3维的凑成一个4维的张量…以此类推,也就是在增加新的维度进行堆叠

outputs = torch.stack(inputs, dim=0) → Tensor

参数

  • inputs : 待连接的张量序列。
    注:python的序列数据只有listtuple

  • dim : 新的维度, 必须在0len(outputs)之间。
    注:len(outputs)是生成数据的维度大小,也就是outputs的维度值。

2. 重点

  1. 函数中的输入inputs只允许是序列;且序列内部的张量元素,必须shape相等

----举例:[tensor_1, tensor_2,..]或者(tensor_1, tensor_2,..),且必须tensor_1.shape == tensor_2.shape

  1. dim是选择生成的维度,必须满足0<=dim<len(outputs)len(outputs)是输出后的tensor的维度大小

不懂的看例子,再回过头看就懂了。

3. 例子

按下面的三步:准备数据,合成inputs,查看结果。

1.准备数据,每个的shape都是[2,3]

# x1
x1 = torch.tensor([[11,21,31],[21,31,41]],dtype=torch.int)
x1.shape # torch.Size([2, 3])
# x2
x2 = torch.tensor([[12,22,32],[22,32,42]],dtype=torch.int)
x2.shape  # torch.Size([2, 3])
# x3
x3 = torch.tensor([[13,23,33],[23,33,43]],dtype=torch.int)
x3.shape # torch.Size([2,3])
# x4
x4 = torch.tensor([[14,24,34],[24,34,44]],dtype=torch.int)
x4.shape # torch.Size([2,3])

2.合成inputs

'inputs为4个形状为[2 , 3]的矩阵 '
inputs = [x1, x2, x3, x4]
print(inputs)
# 打印看看结构。是4个张量
[tensor([[11, 21, 31],
         [21, 31, 41]], dtype=torch.int32),
 tensor([[12, 22, 32],
         [22, 32, 42]], dtype=torch.int32),
 tensor([[13, 23, 33],
         [23, 33, 43]], dtype=torch.int32),
 tensor([[14, 24, 34],
         [24, 34, 44]], dtype=torch.int32)]

3.查看结果, 测试不同的dim拼接结果

'选择的 0<=dim<len()前三个 增加了新维度'
In    [1]: torch.stack(inputs, dim=0).shape
Out[1]: torch.Size([4, 2, 3]) 

In    [2]: torch.stack(inputs, dim=1).shape
Out[2]: torch.Size([2, 4, 3])

In    [3]: torch.stack(inputs, dim=2).shape
Out[3]: torch.Size([2, 3, 4])
'选择的dim>len(outputs),所以报错'
In    [4]: torch.stack(inputs, dim=3).shape
IndexError: Dimension out of range (expected to be in range of [-3, 2], but got 3)

大家可以复制代码运行一下就会发现:这个拼接后的维度大小4根据不同的dim一直变化。

dim shape
0 [4, 2, 3]
1 [2, 4, 3]
2 [2, 3, 4]
3 溢出报错

4. 总结

  1. 函数作用:
    函数stack()序列数据内部的张量进行扩维拼接,指定维度由我们选择、大小是生成后数据的维度区间。

  2. 存在意义:
    在自然语言处理和卷及神经网络中, 通常为了保留–[序列(先后)信息] 和 [张量的矩阵信息] 才会使用stack

研究自然语言处理的同学一般知道,在循环神经网络中,网络的输出数据通常是:包含了n个数据大小[batch_size, num_outputs]list,这个和[n, batch_size, num_outputs]是完全不一样的!!!!不利于计算,需要使用stack进行拼接,保留–[1.时间步]和–[2.张量的矩阵乘积属性]。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章