pytorch學習之七 批訓練

批訓練是什麼東西呢?在之前的迭代訓練代碼中。

for t in range(100):
    out = net(x)
    loss = loss_func(out,y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

一次迭代,需要到用到訓練樣本的所有數據。那麼當訓練集非常大,或者說樣本無法同時取出來的時候,就比較難以訓練,這時候要用上批處理的方法。

何爲批處理呢?就是每次迭代只使用訓練集的一部分作爲一個代表,來訓練整個網絡。這樣可以加速網絡的訓練,同時,精度又不會有太大的下降。

pytorch提供了一些方法來進行批訓練,主要下面兩個

  • Data.TensorDataset
    將訓練樣本的x,y封裝起來的一個數據集類型

  • Data.DataLoader
    這個是將數據集切分的一個工具,一般都是隨機切分。

上代碼

import torch#導入模塊
import torch.utils.data as Data

#每一批的數據量
BATCH_SIZE=5#每一批的數據量

x=torch.linspace(1,10,10)
y=torch.linspace(10,1,10)


#轉換成torch能識別的Dataset
torch_dataset=Data.TensorDataset(x,y)  #將數據放入torch_dataset


#torch.utils.data.DataLoader這個接口定義在dataloader.py腳本中,只要是用PyTorch來訓練
#模型都會用到該接口,該接口主要用來將自動以的數據讀取接口的輸出或者PyTorch已有的數據讀取接口按照batch size
#封裝成Tensor,後續只需要再包裝成Variable即可作爲模型的輸入,因此該接口有點承上啓下的作用,比較
#

loader=Data.DataLoader(
        dataset=torch_dataset,      #將數據放入loader
        batch_size=BATCH_SIZE,      #批的尺寸,五個爲一個批次
        shuffle=True,               #是否打斷數據  
        num_workers=0              #多線程讀取數據,如果爲0就是主線程來讀取數據
        )

#for epoch in range(3):  #訓練所有的整套數據3次
#    for step,(batch_x,batch_y) in enumerate(loader):   #
for epoch in range(3):
    for step,(batch_x,batch_y) in enumerate(loader):  
        print('Epoch:',epoch,'|Step:',step,'|batch x:',batch_x.numpy()
            ,'|batch y:',batch_y.numpy())

loader=Data.DataLoader(
dataset=torch_dataset, #將數據放入loader
batch_size=BATCH_SIZE, #批的尺寸,五個爲一個批次
shuffle=True, #是否打斷數據
num_workers=0 #多線程讀取數據,如果爲0就是主線程來讀取數據
)

dataset就是剛剛建立好的數據集,batch_size是每一批的大小。shuffle,每個批是否是隨機從dataset裏面取數據。num_works,是否是多個線程來讀取數據,我的機器上這個參數爲非零值就會執行失敗,不知道爲什麼

for epoch in range(3):
    for step,(batch_x,batch_y) in enumerate(loader):  
        print('Epoch:',epoch,'|Step:',step,'|batch x:',batch_x.numpy()
            ,'|batch y:',batch_y.numpy())

for step,(batch_x,batch_y) in enumerate(loader):
這句話是生成了一個關於loader的迭代器,然後遍歷loader
step代表索引,batch_x,batch_y代表每次隨機切分的訓練集,大小爲5
打印效果如下

Epoch: 0 |Step: 0 |batch x: [1. 2. 3. 4. 5.] |batch y: [10.  9.  8.  7.  6.]
Epoch: 0 |Step: 1 |batch x: [ 6.  7.  8.  9. 10.] |batch y: [5. 4. 3. 2. 1.]
Epoch: 1 |Step: 0 |batch x: [1. 2. 3. 4. 5.] |batch y: [10.  9.  8.  7.  6.]
Epoch: 1 |Step: 1 |batch x: [ 6.  7.  8.  9. 10.] |batch y: [5. 4. 3. 2. 1.]
Epoch: 2 |Step: 0 |batch x: [1. 2. 3. 4. 5.] |batch y: [10.  9.  8.  7.  6.]
Epoch: 2 |Step: 1 |batch x: [ 6.  7.  8.  9. 10.] |batch y: [5. 4. 3. 2. 1.]
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章