『Pytorch笔记2』Pytorch索引与切片以及维度变换!

Pytorch索引与切片以及维度变换!

一. 索引与切片

1.1. index

import torch

a = torch.rand(4, 3, 28, 28)
print(a.shape)
print(a[0, 0].shape)
print(a[0, 0, 2, 4])
  • 运行结果
torch.Size([4, 3, 28, 28])
torch.Size([28, 28])
tensor(0.0166)

1.2. select first/last N

import torch

a = torch.rand(4, 3, 28, 28)
print(a[:2].shape)
print(a[:2, :, :, :].shape)  	# 意思一样的
print(a[:2, :1, :, :].shape)  	# 意思一样的
print(a[:2, 1:, :, :].shape)
print(a[:2, -1:, :, :].shape)
  • 运行结果
torch.Size([2, 3, 28, 28])
torch.Size([2, 3, 28, 28])
torch.Size([2, 1, 28, 28])
torch.Size([2, 2, 28, 28])
torch.Size([2, 1, 28, 28])

1.3. select by steps

这里: 其实只有一种通用的形式:start : end : step

import torch

a = torch.rand(4, 3, 28, 28)
print(a[:, :, 0:28:2, 0:28:2].shape)  # 意思一样的
print(a[:, :, ::2, ::2].shape)        # 意思一样的
  • 运行结果
torch.Size([4, 3, 14, 14])
torch.Size([4, 3, 14, 14])

1.4. select by specific index

这里: a.index_select(0, torch.tensor([0, 2])).shape 第一参数表示对哪个维度进行采样,第二个参数必须为Tensor。

import torch

a = torch.rand(4, 3, 28, 28)
# 当前对图片张数进行操作,所以给了0;我们采用第0张和第2张
print(a.index_select(0, torch.tensor([0, 2])).shape)

# 第一个维度就是RGB通道。现在取G通道,B通道。
print(a.index_select(1, torch.tensor([1, 2])).shape)
  • 运行结果
torch.Size([2, 3, 28, 28])
torch.Size([4, 2, 28, 28])

Process finished with exit code 0

1.5. 符号…

这里: 符号…代表任意多的维度。

import torch

a = torch.rand(4, 3, 28, 28)
print(a[...].shape)

print(a[0, ...].shape)   # 这里的...等价于下面的:, :, :,
print(a[0].shape)        # 等价的。
print(a[0, :, :, :].shape)
  • 运行结果
torch.Size([4, 3, 28, 28])
torch.Size([3, 28, 28])
torch.Size([3, 28, 28])
torch.Size([3, 28, 28])

Process finished with exit code 0

1.6. select by mask(注意现shape和原没关系)

  • .masked_select()
import torch

a = torch.randn(3, 4)
print(a)

mask = a.ge(0.5)   # ge表示大于等于0.5的数
print(mask)
print('===================================================')
aa = torch.masked_select(a, mask)  # 把大于等于0.5的元素取出来,注意变成向量了,不是矩阵了。
print(aa)
  • 运行结果
tensor([[ 1.0573, -0.1439,  3.9749, -0.1838],
        [-0.4314,  0.0450, -1.0288, -0.5063],
        [-2.8153,  1.4134,  1.5503, -0.5849]])
tensor([[ True, False,  True, False],
        [False, False, False, False],
        [False,  True,  True, False]])
===================================================
tensor([1.0573, 3.9749, 1.4134, 1.5503])

Process finished with exit code 0

二. 维度变换

  • 常用的API

2.1. view和reshape(维度改变)

这里: 这两个API几乎是一模一样的。Pytorch版本是0.3的时候用的函数是view函数,为了和Numpy一致在Pytorch0.4以后,增加了reshape函数。

import torch

a = torch.rand(4, 1, 28, 28)              # 假设表示照片。
print(a.shape)

print(a.view(4, 28*28).shape)             # 这三个是通用的
print(a.reshape(4, 28*28).shape)          # 这三个是通用的
print(torch.reshape(a, [4, 28*28]).shape) # 这三个是通用的
  • 运行结果
torch.Size([4, 1, 28, 28])
torch.Size([4, 784])
torch.Size([4, 784])
torch.Size([4, 784])

Process finished with exit code 0

2.2. squeeze和unsqueeze(维度增加/减少)

这里: 这两个squeeze(如何挤压掉一个维度,默认删除维度的size为1的)和unsqueeze(一个维度如何展开)分别表示挤压,和展开的意思。

  • 例子1
import torch

a = torch.rand(4, 1, 28, 28)              # 假设表示照片。
print(a.shape)

print(a.unsqueeze(0).shape)
print(a.unsqueeze(-1).shape)
print(a.squeeze().shape)            	  # 把1的维度去掉
  • 运行结果
torch.Size([4, 1, 28, 28])
torch.Size([1, 4, 1, 28, 28])
torch.Size([4, 1, 28, 28, 1])
torch.Size([4, 28, 28])

Process finished with exit code 0
  • 例子2
import torch

a = torch.tensor([1.2, 2.3])
print(a)
print(a.unsqueeze(-1))
print(a.unsqueeze(0))
  • 运行结果
tensor([1.2000, 2.3000])
tensor([[1.2000],
        [2.3000]])
tensor([[1.2000, 2.3000]])

Process finished with exit code 0
  • 例子3:图片操作实实在在的案例。
import torch

# bias相当于给每个channel上的所有像素增加一个偏置
b = torch.rand(32)            # 这里表示bias
f = torch.rand(4, 32, 14, 14)

# 如何把f叠加到b上面呢,首先维度不一样的。插入维度。
b = b.unsqueeze(1).unsqueeze(2).unsqueeze(0)
print(b.shape)    # torch.Size([1, 32, 1, 1])

# 扩张成 [4, 32, 14, 14] 后面讲解。

2.3. expand和repeat(维度扩展)

import torch

# bias相当于给每个channel上的所有像素增加一个偏置
b = torch.rand(32)            # 这里表示bias
f = torch.rand(4, 32, 14, 14)

# 如何把f叠加到b上面呢,首先维度不一样的。插入维度。
b = b.unsqueeze(1).unsqueeze(2).unsqueeze(0)
print(b.shape)    # torch.Size([1, 32, 1, 1])
print('===================== = 扩张 ======================== ')
# 扩张成 [4, 32, 14, 14] 后面讲解。
c = b.expand(4, 32, 32, 14)
print(c.shape)
print(b.expand(-1, 32, -1, -1).shape)     # -1保持不变的。
  • 运行结果
torch.Size([1, 32, 1, 1])
===================== = 扩张 ======================== 
torch.Size([4, 32, 32, 14])
torch.Size([1, 32, 1, 1])

Process finished with exit code 0

这里: repeat函数和expand函数不一致,其参数里面给出的不像expand参数是新的维度。repeat给出的而是原来的维度要复制repeat的次数,repeat和matlab中的保持一致,不建议使用这个api,主要是占用内存,它重新申请一片内存空间,赋值新的数据

import torch

a = torch.rand([1, 32, 1, 1])
print(a.shape)
print(a.repeat(4, 32, 1, 1).shape)
print('===========================')
print(a.repeat(4, 1, 1, 1).shape)   # repeat和matlab中是一致的
  • 运行结果
torch.Size([1, 32, 1, 1])
torch.Size([4, 1024, 1, 1])
===========================
torch.Size([4, 32, 1, 1])

Process finished with exit code 0

2.4. .t方法矩阵的转置以及更通用的transpose

这里: transpose()函数表示矩阵的维度交换,接受的参数为要交换的哪两个维度。

import torch

a = torch.rand([3, 3])
print(a)
print(a.t())

print('===================================')
aa = torch.rand(4, 3, 32, 32)  # 【b,c,h,w】比如交换1 3维度,【b,w,h,c】
b = aa.transpose(1, 3).contiguous().view(4, 3*32*32).view(4, 3, 32, 32)

c = aa.transpose(1, 3).contiguous().view(4, 3*32*32).view(4, 32, 32, 3).transpose(1, 3)

# 判断数据和原来的是否一致
print(torch.all(torch.eq(aa, b))) # 不一致

print(torch.all(torch.eq(aa, c))) # 一致的
  • 运行结果
tensor([[0.0356, 0.4528, 0.7141],
        [0.8490, 0.4429, 0.8567],
        [0.3274, 0.6975, 0.7669]])
tensor([[0.0356, 0.8490, 0.3274],
        [0.4528, 0.4429, 0.6975],
        [0.7141, 0.8567, 0.7669]])
===================================
tensor(False)
tensor(True)

Process finished with exit code 0

2.5. permute

这里: transpose()函数一次只能两两交换。【b, c, h, w】=> 【b, w, h, c】,比如原来一个人的图片,交换过后图片可能不是人了,我们还希望变成原来的样子,可以看成多维度交换,其中参数为新的维度顺序。
同样的道理permute函数也会把内存的顺序给打乱,因此要是涉及contious这个错误的时候,需要额外添加.contiguous()函数,来把内存的顺序变得连续。

import torch

a = torch.rand(4, 3, 28, 28)
print(a.transpose(1, 3).shape)      # 交换13维度。
print('=========================================')
b = torch.rand(4, 3, 28, 32)
print(b.transpose(1, 3).shape)
print(b.transpose(1, 3).transpose(1, 2).shape)
print('==================permute================')
print(b.permute(0, 2, 3, 1).shape)
  • 运行结果
torch.Size([4, 28, 28, 3])
=========================================
torch.Size([4, 32, 28, 3])
torch.Size([4, 28, 32, 3])
==================permute================
torch.Size([4, 28, 32, 3])

Process finished with exit code 0
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章