-
Pytorch安裝
推薦使用Anaconda來管理pytorch等python包
由於電腦配置比較老了,原來裝的是windows系統,安裝的Ubuntu虛擬機比較慢,所以只好在windows下來安裝環境了。使用默認的anaconda源經常會出現CondaHTTPError以及無法創建虛擬環境的問題,需要把anaconda源改爲國內的清華大學鏡像就可以了。可以參考[Anaconda 鏡像使用幫助](https://mirror.tuna.tsinghua.edu.cn/help/anaconda/),修改.condarc文件內容爲:
channels:
- defaults
show_channel_urls: true
channel_alias: https://mirrors.tuna.tsinghua.edu.cn/anaconda
default_channels:
- https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
- https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free
- https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/r
- https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/pro
- https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/msys2
custom_channels:
conda-forge: https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud
msys2: https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud
bioconda: https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud
menpo: https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud
pytorch: https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud
simpleitk: https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud
然後打開Anaconda Prompt,創建虛擬環境:
conda create -n pytorch python=3.7
激活虛擬環境:
conda activate pytorch
通過pip安裝pytorch包,可以參考pytorch官方網站根據自己的系統和軟硬件環境選擇合適的版本(由於我的顯卡不支持CUDA,所以安裝的是CPU only的版本):
pip install torch==1.2.0+cpu torchvision==0.4.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
然後可以退出虛擬環境:
deactivate
至此,已經成功創建了另一個conda虛擬環境pytorch,並安裝了pytroch。pycharm的安裝與使用不在此贅述。
可以用以下簡單的代碼來打印出安裝的pytorch的版本:
import torch
print("hello pytorch {}".format(torch.__version__))
Pytorch的基礎數據結構Tensor和Variable
Tensor張量是什麼? 可以把Tensor看作是一個多維數組,是標量(0維張量),向量(1維張量),矩陣(2維張量)的高維拓展。
Variable是torch.autograd中的數據類型,主要用於封裝Tensor,進行自動求導。
data:被包裝的Tensor
grad:data的梯度
grad_fn:創建Tensor的Function,是自動求導的關鍵
requires_grad:指示是否需要梯度
is_leaf:指示是否是葉子結點(張量)
張量的創建
-
直接創建
- 通過torch.tensor()創建
data:可以是list, numpy
dtype:數據類型,默認與data的類型一致
device:所在設備,cuda/cpu
requires_grad:是否需要梯度
import torch
import numpy as np
arr = np.ones((3, 3))
print("ndarray type: ", arr.dtype)
t = torch.tensor(arr)
print(t)
結果如下:
ndarray type: float64
tensor([[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]], dtype=torch.float64)
- 通過torch.from_numpy創建
arr = np.array([[1, 2, 3], [4, 5, 6]])
t = torch.from_numpy(arr)
print(t)
print("array address:", id(arr)) ## show the memory address
print("tensor data address: ", id(t.data))
結果如下:
tensor([[1, 2, 3],
[4, 5, 6]], dtype=torch.int32)
array address: 2796001828096
tensor data address: 2796001869640
可以看出,通過torch.from_numpy創建的tensor和原ndarray共享內存,當修改其中一個的數據,另一個也同時被改動。
-
依據數值創建
t = torch.zeros((3, 3)) # 創建全0張量
"""
tensor([[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]])
"""
t = torch.ones((3, 3)) # 創建全1張量
"""
tensor([[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]])
"""
t = torch.full((3, 4), 5) # 創建值全爲5的3X3的張
"""
tensor([[5., 5., 5., 5.],
[5., 5., 5., 5.],
[5., 5., 5., 5.]])
"""
# 根據提供的size創建全0,全1,全爲fill_value的張量
arr = np.array([[1, 2, 3], [4, 5, 6]])
test_t = torch.from_numpy(arr)
t = torch.zeros_like(test_t)
"""
tensor([[0, 0, 0],
[0, 0, 0]], dtype=torch.int32)
"""
t = torch.ones_like(test_t)
"""
tensor([[1, 1, 1],
[1, 1, 1]], dtype=torch.int32)
"""
t = torch.full_like(test_t, 3)
"""
tensor([[3, 3, 3],
[3, 3, 3]], dtype=torch.int32)
"""
# 創建等差的1維張量
t = torch.arange(2, 10, 2)
### output
## tensor([2, 4, 6, 8])
# 創建均分的1維張量
t = torch.linspace(1, 3, steps=10)
"""
tensor([1.0000, 1.2222, 1.4444, 1.6667, 1.8889, 2.1111, 2.3333, 2.5556, 2.7778,
3.0000])
"""
# 創建對數均分的1維張量
t = torch.logspace(1, 2, steps=8)
"""
tensor([ 10.0000, 13.8950, 19.3070, 26.8270, 37.2759, 51.7948, 71.9686,
100.0000])
"""
# 創建單位對角矩陣
t = torch.eye(3, 3)
"""
tensor([[1., 0., 0.],
[0., 1., 0.],
[0., 0., 1.]])
"""
-
依據概率分佈創建
# 生成正態分佈(高斯分佈)
t_normal = torch.normal(0., 1., size=(4,))
## tensor([ 0.0206, -0.4805, -0.3185, -1.0287])
# 生成標準正態分佈
t = torch.randn((2, 2))
"""
tensor([[ 0.4790, 0.1725],
[-2.1356, -0.2725]])
"""
# 在區間[0, 1)上生成均勻分佈
t = torch.rand((2, 2))
"""
tensor([[0.8578, 0.9063],
[0.9365, 0.7199]])
"""
# 在區間[low, high)生成整數均勻分佈
t = torch.randint(1, 5, (2, 2))
"""
tensor([[3, 2],
[1, 4]])
"""
# 生成從0到n-1的隨機排列
t = torch.randperm(10)
# tensor([1, 9, 3, 2, 6, 4, 5, 0, 7, 8])