pytorch(五):批訓練

import torch
import torch.utils.data as Data


# 虛構要訓練的數據
x = torch.linspace(11, 20, 10)  # 在[11, 20]裏取出10個間隔相等的數 (torch tensor)
y = torch.linspace(20, 11, 10)


BATCH_SIZE = 5  # 每批需要訓練的數據個數


# 把tensor轉換成torch能識別的數據集
torch_dataset = Data.TensorDataset(data_tensor=x, target_tensor=y)


# 把數據集放進數據裝載機裏
loader = Data.DataLoader(
    dataset=torch_dataset,  # 數據集
    batch_size=BATCH_SIZE,  # 每批需要訓練的數據個數
    shuffle=True,  # 是否打亂取數據的順序(打亂的訓練效果更好)
    num_workers=2,  # 多線程讀取數據
)


# 批量取出數據來訓練
for epoch in range(3):  # 把整套數據重複訓練3遍
    for step, (batch_x, batch_y) in enumerate(loader):  # 每次從數據裝載機裏取出批量數據來訓練
        # 以下爲訓練的地方
        # …………
        # 把每遍裏每次取出的數據打印出來
        print('Epoch:', epoch, '|Step:', step,  # Epoch表示哪一遍, Step表示哪一次
              'batch x:', batch_x.numpy(),
              'batch y:', batch_y.numpy(),
        )

運行結果:

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章