Pytorch索引与切片以及维度变换! |
文章目录
一. 索引与切片
1.1. index
import torch
a = torch.rand(4, 3, 28, 28)
print(a.shape)
print(a[0, 0].shape)
print(a[0, 0, 2, 4])
- 运行结果
torch.Size([4, 3, 28, 28])
torch.Size([28, 28])
tensor(0.0166)
1.2. select first/last N
import torch
a = torch.rand(4, 3, 28, 28)
print(a[:2].shape)
print(a[:2, :, :, :].shape) # 意思一样的
print(a[:2, :1, :, :].shape) # 意思一样的
print(a[:2, 1:, :, :].shape)
print(a[:2, -1:, :, :].shape)
- 运行结果
torch.Size([2, 3, 28, 28])
torch.Size([2, 3, 28, 28])
torch.Size([2, 1, 28, 28])
torch.Size([2, 2, 28, 28])
torch.Size([2, 1, 28, 28])
1.3. select by steps
这里: 其实只有一种通用的形式:
start : end : step
。
import torch
a = torch.rand(4, 3, 28, 28)
print(a[:, :, 0:28:2, 0:28:2].shape) # 意思一样的
print(a[:, :, ::2, ::2].shape) # 意思一样的
- 运行结果
torch.Size([4, 3, 14, 14])
torch.Size([4, 3, 14, 14])
1.4. select by specific index
这里:
a.index_select(0, torch.tensor([0, 2])).shape
第一参数表示对哪个维度进行采样,第二个参数必须为Tensor。
import torch
a = torch.rand(4, 3, 28, 28)
# 当前对图片张数进行操作,所以给了0;我们采用第0张和第2张
print(a.index_select(0, torch.tensor([0, 2])).shape)
# 第一个维度就是RGB通道。现在取G通道,B通道。
print(a.index_select(1, torch.tensor([1, 2])).shape)
- 运行结果
torch.Size([2, 3, 28, 28])
torch.Size([4, 2, 28, 28])
Process finished with exit code 0
1.5. 符号…
这里: 符号…代表任意多的维度。
import torch
a = torch.rand(4, 3, 28, 28)
print(a[...].shape)
print(a[0, ...].shape) # 这里的...等价于下面的:, :, :,
print(a[0].shape) # 等价的。
print(a[0, :, :, :].shape)
- 运行结果
torch.Size([4, 3, 28, 28])
torch.Size([3, 28, 28])
torch.Size([3, 28, 28])
torch.Size([3, 28, 28])
Process finished with exit code 0
1.6. select by mask(注意现shape和原没关系)
.masked_select()
import torch
a = torch.randn(3, 4)
print(a)
mask = a.ge(0.5) # ge表示大于等于0.5的数
print(mask)
print('===================================================')
aa = torch.masked_select(a, mask) # 把大于等于0.5的元素取出来,注意变成向量了,不是矩阵了。
print(aa)
- 运行结果
tensor([[ 1.0573, -0.1439, 3.9749, -0.1838],
[-0.4314, 0.0450, -1.0288, -0.5063],
[-2.8153, 1.4134, 1.5503, -0.5849]])
tensor([[ True, False, True, False],
[False, False, False, False],
[False, True, True, False]])
===================================================
tensor([1.0573, 3.9749, 1.4134, 1.5503])
Process finished with exit code 0
二. 维度变换
- 常用的API
2.1. view和reshape(维度改变)
这里: 这两个API几乎是一模一样的。Pytorch版本是0.3的时候用的函数是view函数,为了和Numpy一致在Pytorch0.4以后,增加了reshape函数。
import torch
a = torch.rand(4, 1, 28, 28) # 假设表示照片。
print(a.shape)
print(a.view(4, 28*28).shape) # 这三个是通用的
print(a.reshape(4, 28*28).shape) # 这三个是通用的
print(torch.reshape(a, [4, 28*28]).shape) # 这三个是通用的
- 运行结果
torch.Size([4, 1, 28, 28])
torch.Size([4, 784])
torch.Size([4, 784])
torch.Size([4, 784])
Process finished with exit code 0
2.2. squeeze和unsqueeze(维度增加/减少)
这里: 这两个squeeze(如何挤压掉一个维度,默认删除维度的size为1的)和unsqueeze(一个维度如何展开)分别表示挤压,和展开的意思。
- 例子1
import torch
a = torch.rand(4, 1, 28, 28) # 假设表示照片。
print(a.shape)
print(a.unsqueeze(0).shape)
print(a.unsqueeze(-1).shape)
print(a.squeeze().shape) # 把1的维度去掉
- 运行结果
torch.Size([4, 1, 28, 28])
torch.Size([1, 4, 1, 28, 28])
torch.Size([4, 1, 28, 28, 1])
torch.Size([4, 28, 28])
Process finished with exit code 0
- 例子2
import torch
a = torch.tensor([1.2, 2.3])
print(a)
print(a.unsqueeze(-1))
print(a.unsqueeze(0))
- 运行结果
tensor([1.2000, 2.3000])
tensor([[1.2000],
[2.3000]])
tensor([[1.2000, 2.3000]])
Process finished with exit code 0
- 例子3:图片操作实实在在的案例。
import torch
# bias相当于给每个channel上的所有像素增加一个偏置
b = torch.rand(32) # 这里表示bias
f = torch.rand(4, 32, 14, 14)
# 如何把f叠加到b上面呢,首先维度不一样的。插入维度。
b = b.unsqueeze(1).unsqueeze(2).unsqueeze(0)
print(b.shape) # torch.Size([1, 32, 1, 1])
# 扩张成 [4, 32, 14, 14] 后面讲解。
2.3. expand和repeat(维度扩展)
import torch
# bias相当于给每个channel上的所有像素增加一个偏置
b = torch.rand(32) # 这里表示bias
f = torch.rand(4, 32, 14, 14)
# 如何把f叠加到b上面呢,首先维度不一样的。插入维度。
b = b.unsqueeze(1).unsqueeze(2).unsqueeze(0)
print(b.shape) # torch.Size([1, 32, 1, 1])
print('===================== = 扩张 ======================== ')
# 扩张成 [4, 32, 14, 14] 后面讲解。
c = b.expand(4, 32, 32, 14)
print(c.shape)
print(b.expand(-1, 32, -1, -1).shape) # -1保持不变的。
- 运行结果
torch.Size([1, 32, 1, 1])
===================== = 扩张 ========================
torch.Size([4, 32, 32, 14])
torch.Size([1, 32, 1, 1])
Process finished with exit code 0
这里: repeat函数和expand函数不一致,其参数里面给出的不像expand参数是新的维度。repeat给出的而是原来的维度要复制repeat的次数,repeat和matlab中的保持一致,不建议使用这个api,主要是占用内存,它重新申请一片内存空间,赋值新的数据。
import torch
a = torch.rand([1, 32, 1, 1])
print(a.shape)
print(a.repeat(4, 32, 1, 1).shape)
print('===========================')
print(a.repeat(4, 1, 1, 1).shape) # repeat和matlab中是一致的
- 运行结果
torch.Size([1, 32, 1, 1])
torch.Size([4, 1024, 1, 1])
===========================
torch.Size([4, 32, 1, 1])
Process finished with exit code 0
2.4. .t方法矩阵的转置以及更通用的transpose
这里: transpose()函数表示矩阵的维度交换,接受的参数为要交换的哪两个维度。
import torch
a = torch.rand([3, 3])
print(a)
print(a.t())
print('===================================')
aa = torch.rand(4, 3, 32, 32) # 【b,c,h,w】比如交换1 3维度,【b,w,h,c】
b = aa.transpose(1, 3).contiguous().view(4, 3*32*32).view(4, 3, 32, 32)
c = aa.transpose(1, 3).contiguous().view(4, 3*32*32).view(4, 32, 32, 3).transpose(1, 3)
# 判断数据和原来的是否一致
print(torch.all(torch.eq(aa, b))) # 不一致
print(torch.all(torch.eq(aa, c))) # 一致的
- 运行结果
tensor([[0.0356, 0.4528, 0.7141],
[0.8490, 0.4429, 0.8567],
[0.3274, 0.6975, 0.7669]])
tensor([[0.0356, 0.8490, 0.3274],
[0.4528, 0.4429, 0.6975],
[0.7141, 0.8567, 0.7669]])
===================================
tensor(False)
tensor(True)
Process finished with exit code 0
2.5. permute
这里: transpose()函数一次只能两两交换。【b, c, h, w】=> 【b, w, h, c】,比如原来一个人的图片,交换过后图片可能不是人了,我们还希望变成原来的样子,可以看成多维度交换,其中参数为新的维度顺序。
同样的道理permute函数也会把内存的顺序给打乱,因此要是涉及contious这个错误的时候,需要额外添加.contiguous()函数,来把内存的顺序变得连续。
import torch
a = torch.rand(4, 3, 28, 28)
print(a.transpose(1, 3).shape) # 交换13维度。
print('=========================================')
b = torch.rand(4, 3, 28, 32)
print(b.transpose(1, 3).shape)
print(b.transpose(1, 3).transpose(1, 2).shape)
print('==================permute================')
print(b.permute(0, 2, 3, 1).shape)
- 运行结果
torch.Size([4, 28, 28, 3])
=========================================
torch.Size([4, 32, 28, 3])
torch.Size([4, 28, 32, 3])
==================permute================
torch.Size([4, 28, 32, 3])
Process finished with exit code 0