pytorch-Mnist分類任務

讀取數據
import pickle
import gzip
# Decompress and unpickle the MNIST data file: three (images, labels)
# splits; the third (test) split is discarded here.
# NOTE(review): PATH and FILENAME are defined earlier in the original
# article (not visible in this chunk) — confirm before running standalone.
with gzip.open((PATH / FILENAME).as_posix(), "rb") as f:
        ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding="latin-1")
# Convert the loaded arrays to torch tensors so they can feed the network.
import torch

x_train, y_train, x_valid, y_valid = map(
    torch.tensor, (x_train, y_train, x_valid, y_valid)
)
創建網絡結構
class Mnist_NN(nn.Module):
    """Fully-connected classifier for flattened 28x28 MNIST images.

    Architecture: 784 -> 128 -> 256 -> 10, with ReLU after each hidden
    layer and raw logits out (loss is expected to apply softmax).
    """

    def __init__(self):
        super().__init__()
        # Layer names (hidden1/hidden2/out) are part of the public state:
        # they appear in named_parameters() and checkpoints.
        self.hidden1 = nn.Linear(784, 128)
        self.hidden2 = nn.Linear(128, 256)
        self.out = nn.Linear(256, 10)

    def forward(self, x):
        """Map a (batch, 784) float tensor to (batch, 10) class logits."""
        activated = F.relu(self.hidden1(x))
        activated = F.relu(self.hidden2(activated))
        return self.out(activated)
        
# Instantiate the network and print its layer structure for inspection.
net = Mnist_NN()
print(net)

我們可以打印下網絡結構

Mnist_NN(
(hidden1): Linear(in_features=784, out_features=128, bias=True)
(hidden2): Linear(in_features=128, out_features=256, bias=True)
(out): Linear(in_features=256, out_features=10, bias=True)
)

再看下參數情況

# List every trainable tensor of the network together with its shape.
for param_name, param in net.named_parameters():
    print(param_name, param.size())

hidden1.weight torch.Size([128, 784])
hidden1.bias torch.Size([128])
hidden2.weight torch.Size([256, 128])
hidden2.bias torch.Size([256])
out.weight torch.Size([10, 256])
out.bias torch.Size([10])

使用TensorDataset和DataLoader來構建數據
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader

bs = 64  # mini-batch size

# Pair images with labels so the loaders can index (x, y) together.
# Training loader shuffles each epoch; validation loader keeps order.
train_ds = TensorDataset(x_train, y_train)
train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True)

valid_ds = TensorDataset(x_valid, y_valid)
valid_dl = DataLoader(valid_ds, batch_size=bs )


def get_data(train_ds, valid_ds, bs):
    """Wrap the two datasets in DataLoaders with batch size *bs*.

    The training loader reshuffles every epoch; the validation loader
    keeps the original order (shuffling adds nothing to evaluation).
    """
    train_loader = DataLoader(train_ds, batch_size=bs, shuffle=True)
    valid_loader = DataLoader(valid_ds, batch_size=bs)
    return train_loader, valid_loader
訓練過程
loss_func = F.cross_entropy

#更新梯度
def loss_batch(model, loss_func, xb, yb, opt=None):
    """Compute the loss on one mini-batch; if an optimizer is supplied,
    also backpropagate and apply a single parameter update.

    Returns:
        (loss_value, batch_size): the scalar loss and len(xb), which the
        caller uses to form a size-weighted average over batches.
    """
    batch_loss = loss_func(model(xb), yb)

    # Training path only: backprop, step, then clear gradients so the
    # next batch starts from zero.
    if opt is not None:
        batch_loss.backward()
        opt.step()
        opt.zero_grad()

    return batch_loss.item(), len(xb)
    
#返回模型及參數    
def get_model(lr=0.001):
    """Build a fresh Mnist_NN and an SGD optimizer over its parameters.

    Args:
        lr: learning rate for SGD. Defaults to 0.001, the value the
            original hard-coded, so existing callers are unaffected.

    Returns:
        (model, optimizer) tuple ready for training.
    """
    model = Mnist_NN()
    return model, optim.SGD(model.parameters(), lr=lr)
 
 
#訓練
def fit(steps, model, loss_func, opt, train_dl, valid_dl):
    """Train for *steps* epochs, printing the validation loss after each.

    Each epoch runs one full pass of gradient updates over train_dl,
    then evaluates on valid_dl with gradients disabled; the reported
    loss is a batch-size-weighted average.
    """
    for step in range(steps):
        # Training pass: update parameters batch by batch.
        model.train()
        for xb, yb in train_dl:
            loss_batch(model, loss_func, xb, yb, opt)

        # Validation pass: no optimizer, no gradient tracking.
        model.eval()
        batch_results = []
        with torch.no_grad():
            for xb, yb in valid_dl:
                batch_results.append(loss_batch(model, loss_func, xb, yb))
        losses, nums = zip(*batch_results)
        # Weight each batch loss by its batch size (last batch may be short).
        val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums)
        print('當前step:'+str(step), '驗證集損失:'+str(val_loss))
        
        
#開始訓練        
# Build the loaders, model and optimizer, then train for 25 epochs.
train_dl, valid_dl = get_data(train_ds, valid_ds, bs)
model, opt = get_model()
fit(25, model, loss_func, opt, train_dl, valid_dl)

當前step:0 驗證集損失:2.2796445930480957
當前step:1 驗證集損失:2.2440698066711424
當前step:2 驗證集損失:2.1889826164245605
當前step:3 驗證集損失:2.0985311767578123
當前step:4 驗證集損失:1.9517273582458496
當前step:5 驗證集損失:1.7341805934906005
當前step:6 驗證集損失:1.4719875366210937
當前step:7 驗證集損失:1.2273896869659424
當前step:8 驗證集損失:1.0362271406173706
當前step:9 驗證集損失:0.8963696184158325
當前step:10 驗證集損失:0.7927186088562012
當前step:11 驗證集損失:0.7141492074012756
當前step:12 驗證集損失:0.6529350900650024
當前step:13 驗證集損失:0.60417300491333
當前step:14 驗證集損失:0.5643046331882476
當前step:15 驗證集損失:0.5317994566917419
當前step:16 驗證集損失:0.5047958114624024
當前step:17 驗證集損失:0.4813900615692139
當前step:18 驗證集損失:0.4618900228500366
當前step:19 驗證集損失:0.4443243554592133
當前step:20 驗證集損失:0.4297310716629028
當前step:21 驗證集損失:0.416976597738266
當前step:22 驗證集損失:0.406348459148407
當前step:23 驗證集損失:0.3963301926612854
當前step:24 驗證集損失:0.38733808159828187

完整代碼

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章