批訓練是什麼東西呢?在之前的迭代訓練代碼中。
for t in range(100):
out = net(x)
loss = loss_func(out,y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
一次迭代,需要到用到訓練樣本的所有數據。那麼當訓練集非常大,或者說樣本無法同時取出來的時候,就比較難以訓練,這時候要用上批處理的方法。
何爲批處理呢?就是每次迭代只使用訓練集的一部分作爲一個代表,來訓練整個網絡。這樣可以加速網絡的訓練,同時,精度又不會有太大的下降。
pytorch提供了一些方法來進行批訓練,主要下面兩個
-
Data.TensorDataset
將訓練樣本的x,y封裝起來的一個數據集類型 -
Data.DataLoader
這個是將數據集切分的一個工具,一般都是隨機切分。
上代碼
import torch#導入模塊
import torch.utils.data as Data
#每一批的數據量
BATCH_SIZE=5#每一批的數據量
x=torch.linspace(1,10,10)
y=torch.linspace(10,1,10)
#轉換成torch能識別的Dataset
torch_dataset=Data.TensorDataset(x,y) #將數據放入torch_dataset
#torch.utils.data.DataLoader這個接口定義在dataloader.py腳本中,只要是用PyTorch來訓練
#模型都會用到該接口,該接口主要用來將自動以的數據讀取接口的輸出或者PyTorch已有的數據讀取接口按照batch size
#封裝成Tensor,後續只需要再包裝成Variable即可作爲模型的輸入,因此該接口有點承上啓下的作用,比較
#
loader=Data.DataLoader(
dataset=torch_dataset, #將數據放入loader
batch_size=BATCH_SIZE, #批的尺寸,五個爲一個批次
shuffle=True, #是否打斷數據
num_workers=0 #多線程讀取數據,如果爲0就是主線程來讀取數據
)
#for epoch in range(3): #訓練所有的整套數據3次
# for step,(batch_x,batch_y) in enumerate(loader): #
for epoch in range(3):
for step,(batch_x,batch_y) in enumerate(loader):
print('Epoch:',epoch,'|Step:',step,'|batch x:',batch_x.numpy()
,'|batch y:',batch_y.numpy())
loader=Data.DataLoader(
dataset=torch_dataset, #將數據放入loader
batch_size=BATCH_SIZE, #批的尺寸,五個爲一個批次
shuffle=True, #是否打斷數據
num_workers=0 #多線程讀取數據,如果爲0就是主線程來讀取數據
)
dataset就是剛剛建立好的數據集,batch_size是每一批的大小。shuffle,每個批是否是隨機從dataset裏面取數據。num_works,是否是多個線程來讀取數據,我的機器上這個參數爲非零值就會執行失敗,不知道爲什麼
for epoch in range(3):
for step,(batch_x,batch_y) in enumerate(loader):
print('Epoch:',epoch,'|Step:',step,'|batch x:',batch_x.numpy()
,'|batch y:',batch_y.numpy())
for step,(batch_x,batch_y) in enumerate(loader):
這句話是生成了一個關於loader的迭代器,然後遍歷loader
step代表索引,batch_x,batch_y代表每次隨機切分的訓練集,大小爲5
打印效果如下
Epoch: 0 |Step: 0 |batch x: [1. 2. 3. 4. 5.] |batch y: [10. 9. 8. 7. 6.]
Epoch: 0 |Step: 1 |batch x: [ 6. 7. 8. 9. 10.] |batch y: [5. 4. 3. 2. 1.]
Epoch: 1 |Step: 0 |batch x: [1. 2. 3. 4. 5.] |batch y: [10. 9. 8. 7. 6.]
Epoch: 1 |Step: 1 |batch x: [ 6. 7. 8. 9. 10.] |batch y: [5. 4. 3. 2. 1.]
Epoch: 2 |Step: 0 |batch x: [1. 2. 3. 4. 5.] |batch y: [10. 9. 8. 7. 6.]
Epoch: 2 |Step: 1 |batch x: [ 6. 7. 8. 9. 10.] |batch y: [5. 4. 3. 2. 1.]