前言
PyTorch中的數據類型爲Tensor,Tensor與Numpy中的ndarray類似,同樣可以用於標量,向量,矩陣乃至更高維度上面的計算。PyTorch中的tensor又包括CPU上的數據類型和GPU上的數據類型,一般GPU上的Tensor是CPU上的Tensor加cuda()函數得到。通過使用Type函數可以查看變量類型。系統默認的torch.Tensor是torch.FloatTensor類型。例如data = torch.Tensor(2,3)是一個2*3的張量,類型爲FloatTensor; data.cuda()就將其轉換爲GPU的張量類型,torch.cuda.FloatTensor類型。
① 基本類型
如圖所示,下面是cpu和gpu版本的張量(Tensor)的基本類型,一共是8種。
- torch.FloatTensor(2, 2) 構建一個2*2
Float
類型的張量
- torch.DoubleTensor(2, 2) 構建一個2*2
Double
類型的張量 - torch.ByteTensor(2, 2) 構建一個2*2
Byte
類型的張量 - torch.CharTensor(2, 2) 構建一個2*2
Char
類型的張量 - torch.ShortTensor(2, 2) 構建一個2*2
Short
類型的張量 - torch.IntTensor(2, 2) 構建一個2*2
Int
類型的張量 - torch.LongTensor(2, 2) 構建一個2*2
Long
類型的張量
官網還介紹了從python的基本數據類型list和科學計算庫numpy.ndarray轉換爲Tensor的例子:
>>> torch.tensor([[1., -1.], [1., -1.]])
tensor([[ 1.0000, -1.0000],
[ 1.0000, -1.0000]])
>>> torch.tensor(np.array([[1, 2, 3], [4, 5, 6]]))
tensor([[ 1, 2, 3],
[ 4, 5, 6]])
② 張量類型之間的轉換
2.1 CPU和GPU的Tensor之間轉換
從cpu –> gpu,使用data.cuda()
即可。
若從gpu –> cpu,則使用data.cpu()
。
2.2 Tensor與Numpy Array之間的轉換
Tensor –> Numpy.ndarray 可以使用 data.numpy()
,其中data的類型爲torch.Tensor
。
Numpy.ndarray –> Tensor 可以使用torch.from_numpy(data)
,其中data的類型爲numpy.ndarray
。
2.3 Tensor的基本類型轉換(也就是float轉double,轉byte這種。)
爲了方便測試,我們構建一個新的張量,你要轉變成不同的類型只需要根據自己的需求選擇即可
tensor = torch.Tensor(2, 5)
torch.long() 將tensor投射爲long類型
newtensor = tensor.long()torch.half()將tensor投射爲半精度浮點(16位浮點)類型
newtensor = tensor.half()torch.int()將該tensor投射爲int類型
newtensor = tensor.int()torch.double()將該tensor投射爲double類型
newtensor = tensor.double()torch.float()將該tensor投射爲float類型
newtensor = tensor.float()torch.char()將該tensor投射爲char類型
newtensor = tensor.char()torch.byte()將該tensor投射爲byte類型
newtensor = tensor.byte()torch.short()將該tensor投射爲short類型
newtensor = tensor.short()
思考
據我目前使用來看,最常用的還是Tensor.byte(), Tensor.float()。因爲pytorch底層很多計算的邏輯默認需要的是這些類型。但是如果當你需要提高精度,比如說想把模型從float變爲double。那麼可以將要訓練的模型設置爲model = model.double()
。此外,還要對所有的張量進行設置:pytorch.set_default_tensor_type('torch.DoubleTensor')
,不過double比float要慢很多,要結合實際情況進行思考。
參考資料
[1] pytorch張量torch.Tensor類型的構建與相互轉換以及torch.type()和torch.type_as()的用法
[2] PyTorch torch.Tensor 教程