『Pytorch筆記2』Pytorch索引與切片以及維度變換!

Pytorch索引與切片以及維度變換!

一. 索引與切片

1.1. index

import torch

a = torch.rand(4, 3, 28, 28)
print(a.shape)
print(a[0, 0].shape)
print(a[0, 0, 2, 4])
  • 運行結果
torch.Size([4, 3, 28, 28])
torch.Size([28, 28])
tensor(0.0166)

1.2. select first/last N

import torch

a = torch.rand(4, 3, 28, 28)
print(a[:2].shape)
print(a[:2, :, :, :].shape)  	# 意思一樣的
print(a[:2, :1, :, :].shape)  	# 意思一樣的
print(a[:2, 1:, :, :].shape)
print(a[:2, -1:, :, :].shape)
  • 運行結果
torch.Size([2, 3, 28, 28])
torch.Size([2, 3, 28, 28])
torch.Size([2, 1, 28, 28])
torch.Size([2, 2, 28, 28])
torch.Size([2, 1, 28, 28])

1.3. select by steps

這裏: 其實只有一種通用的形式:start : end : step

import torch

a = torch.rand(4, 3, 28, 28)
print(a[:, :, 0:28:2, 0:28:2].shape)  # 意思一樣的
print(a[:, :, ::2, ::2].shape)        # 意思一樣的
  • 運行結果
torch.Size([4, 3, 14, 14])
torch.Size([4, 3, 14, 14])

1.4. select by specific index

這裏: a.index_select(0, torch.tensor([0, 2])).shape 第一參數表示對哪個維度進行採樣,第二個參數必須爲Tensor。

import torch

a = torch.rand(4, 3, 28, 28)
# 當前對圖片張數進行操作,所以給了0;我們採用第0張和第2張
print(a.index_select(0, torch.tensor([0, 2])).shape)

# 第一個維度就是RGB通道。現在取G通道,B通道。
print(a.index_select(1, torch.tensor([1, 2])).shape)
  • 運行結果
torch.Size([2, 3, 28, 28])
torch.Size([4, 2, 28, 28])

Process finished with exit code 0

1.5. 符號…

這裏: 符號…代表任意多的維度。

import torch

a = torch.rand(4, 3, 28, 28)
print(a[...].shape)

print(a[0, ...].shape)   # 這裏的...等價於下面的:, :, :,
print(a[0].shape)        # 等價的。
print(a[0, :, :, :].shape)
  • 運行結果
torch.Size([4, 3, 28, 28])
torch.Size([3, 28, 28])
torch.Size([3, 28, 28])
torch.Size([3, 28, 28])

Process finished with exit code 0

1.6. select by mask(注意現shape和原沒關係)

  • .masked_select()
import torch

a = torch.randn(3, 4)
print(a)

mask = a.ge(0.5)   # ge表示大於等於0.5的數
print(mask)
print('===================================================')
aa = torch.masked_select(a, mask)  # 把大於等於0.5的元素取出來,注意變成向量了,不是矩陣了。
print(aa)
  • 運行結果
tensor([[ 1.0573, -0.1439,  3.9749, -0.1838],
        [-0.4314,  0.0450, -1.0288, -0.5063],
        [-2.8153,  1.4134,  1.5503, -0.5849]])
tensor([[ True, False,  True, False],
        [False, False, False, False],
        [False,  True,  True, False]])
===================================================
tensor([1.0573, 3.9749, 1.4134, 1.5503])

Process finished with exit code 0

二. 維度變換

  • 常用的API

2.1. view和reshape(維度改變)

這裏: 這兩個API幾乎是一模一樣的。Pytorch版本是0.3的時候用的函數是view函數,爲了和Numpy一致在Pytorch0.4以後,增加了reshape函數。

import torch

a = torch.rand(4, 1, 28, 28)              # 假設表示照片。
print(a.shape)

print(a.view(4, 28*28).shape)             # 這三個是通用的
print(a.reshape(4, 28*28).shape)          # 這三個是通用的
print(torch.reshape(a, [4, 28*28]).shape) # 這三個是通用的
  • 運行結果
torch.Size([4, 1, 28, 28])
torch.Size([4, 784])
torch.Size([4, 784])
torch.Size([4, 784])

Process finished with exit code 0

2.2. squeeze和unsqueeze(維度增加/減少)

這裏: 這兩個squeeze(如何擠壓掉一個維度,默認刪除維度的size爲1的)和unsqueeze(一個維度如何展開)分別表示擠壓,和展開的意思。

  • 例子1
import torch

a = torch.rand(4, 1, 28, 28)              # 假設表示照片。
print(a.shape)

print(a.unsqueeze(0).shape)
print(a.unsqueeze(-1).shape)
print(a.squeeze().shape)            	  # 把1的維度去掉
  • 運行結果
torch.Size([4, 1, 28, 28])
torch.Size([1, 4, 1, 28, 28])
torch.Size([4, 1, 28, 28, 1])
torch.Size([4, 28, 28])

Process finished with exit code 0
  • 例子2
import torch

a = torch.tensor([1.2, 2.3])
print(a)
print(a.unsqueeze(-1))
print(a.unsqueeze(0))
  • 運行結果
tensor([1.2000, 2.3000])
tensor([[1.2000],
        [2.3000]])
tensor([[1.2000, 2.3000]])

Process finished with exit code 0
  • 例子3:圖片操作實實在在的案例。
import torch

# bias相當於給每個channel上的所有像素增加一個偏置
b = torch.rand(32)            # 這裏表示bias
f = torch.rand(4, 32, 14, 14)

# 如何把f疊加到b上面呢,首先維度不一樣的。插入維度。
b = b.unsqueeze(1).unsqueeze(2).unsqueeze(0)
print(b.shape)    # torch.Size([1, 32, 1, 1])

# 擴張成 [4, 32, 14, 14] 後面講解。

2.3. expand和repeat(維度擴展)

import torch

# bias相當於給每個channel上的所有像素增加一個偏置
b = torch.rand(32)            # 這裏表示bias
f = torch.rand(4, 32, 14, 14)

# 如何把f疊加到b上面呢,首先維度不一樣的。插入維度。
b = b.unsqueeze(1).unsqueeze(2).unsqueeze(0)
print(b.shape)    # torch.Size([1, 32, 1, 1])
print('===================== = 擴張 ======================== ')
# 擴張成 [4, 32, 14, 14] 後面講解。
c = b.expand(4, 32, 32, 14)
print(c.shape)
print(b.expand(-1, 32, -1, -1).shape)     # -1保持不變的。
  • 運行結果
torch.Size([1, 32, 1, 1])
===================== = 擴張 ======================== 
torch.Size([4, 32, 32, 14])
torch.Size([1, 32, 1, 1])

Process finished with exit code 0

這裏: repeat函數和expand函數不一致,其參數裏面給出的不像expand參數是新的維度。repeat給出的而是原來的維度要複製repeat的次數,repeat和matlab中的保持一致,不建議使用這個api,主要是佔用內存,它重新申請一片內存空間,賦值新的數據

import torch

a = torch.rand([1, 32, 1, 1])
print(a.shape)
print(a.repeat(4, 32, 1, 1).shape)
print('===========================')
print(a.repeat(4, 1, 1, 1).shape)   # repeat和matlab中是一致的
  • 運行結果
torch.Size([1, 32, 1, 1])
torch.Size([4, 1024, 1, 1])
===========================
torch.Size([4, 32, 1, 1])

Process finished with exit code 0

2.4. .t方法矩陣的轉置以及更通用的transpose

這裏: transpose()函數表示矩陣的維度交換,接受的參數爲要交換的哪兩個維度。

import torch

a = torch.rand([3, 3])
print(a)
print(a.t())

print('===================================')
aa = torch.rand(4, 3, 32, 32)  # 【b,c,h,w】比如交換1 3維度,【b,w,h,c】
b = aa.transpose(1, 3).contiguous().view(4, 3*32*32).view(4, 3, 32, 32)

c = aa.transpose(1, 3).contiguous().view(4, 3*32*32).view(4, 32, 32, 3).transpose(1, 3)

# 判斷數據和原來的是否一致
print(torch.all(torch.eq(aa, b))) # 不一致

print(torch.all(torch.eq(aa, c))) # 一致的
  • 運行結果
tensor([[0.0356, 0.4528, 0.7141],
        [0.8490, 0.4429, 0.8567],
        [0.3274, 0.6975, 0.7669]])
tensor([[0.0356, 0.8490, 0.3274],
        [0.4528, 0.4429, 0.6975],
        [0.7141, 0.8567, 0.7669]])
===================================
tensor(False)
tensor(True)

Process finished with exit code 0

2.5. permute

這裏: transpose()函數一次只能兩兩交換。【b, c, h, w】=> 【b, w, h, c】,比如原來一個人的圖片,交換過後圖片可能不是人了,我們還希望變成原來的樣子,可以看成多維度交換,其中參數爲新的維度順序。
同樣的道理permute函數也會把內存的順序給打亂,因此要是涉及contious這個錯誤的時候,需要額外添加.contiguous()函數,來把內存的順序變得連續。

import torch

a = torch.rand(4, 3, 28, 28)
print(a.transpose(1, 3).shape)      # 交換13維度。
print('=========================================')
b = torch.rand(4, 3, 28, 32)
print(b.transpose(1, 3).shape)
print(b.transpose(1, 3).transpose(1, 2).shape)
print('==================permute================')
print(b.permute(0, 2, 3, 1).shape)
  • 運行結果
torch.Size([4, 28, 28, 3])
=========================================
torch.Size([4, 32, 28, 3])
torch.Size([4, 28, 32, 3])
==================permute================
torch.Size([4, 28, 32, 3])

Process finished with exit code 0
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章