讀取數據
import pickle
import gzip

# Decompress and unpickle the MNIST dataset.
# NOTE(review): PATH and FILENAME are not defined in this chunk — presumably a
# pathlib.Path and the dataset's .pkl.gz filename set earlier; confirm upstream.
# NOTE(review): pickle.load must only be used on trusted files — fine for a
# local MNIST dump, never for downloaded/untrusted data.
with gzip.open((PATH / FILENAME).as_posix(), "rb") as f:
    # The pickle holds ((train_x, train_y), (valid_x, valid_y), test); the
    # test split is discarded here.
    ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding="latin-1")

# Convert the loaded arrays to torch tensors.
import torch

x_train, y_train, x_valid, y_valid = map(
    torch.tensor, (x_train, y_train, x_valid, y_valid)
)
創建網絡結構
class Mnist_NN(nn.Module):
    """A small fully connected classifier for MNIST.

    Architecture: 784 -> 128 -> 256 -> 10 with ReLU activations on the two
    hidden layers. The output is raw logits (no softmax) — pair it with
    a cross-entropy loss that expects logits.
    """

    def __init__(self):
        super().__init__()
        self.hidden1 = nn.Linear(784, 128)
        self.hidden2 = nn.Linear(128, 256)
        self.out = nn.Linear(256, 10)

    def forward(self, x):
        """Map a (batch, 784) input to (batch, 10) class logits."""
        for hidden in (self.hidden1, self.hidden2):
            x = F.relu(hidden(x))
        return self.out(x)
# Instantiate the network and print its layer structure.
net = Mnist_NN()
print(net)
我們可以打印下網絡結構
Mnist_NN(
(hidden1): Linear(in_features=784, out_features=128, bias=True)
(hidden2): Linear(in_features=128, out_features=256, bias=True)
(out): Linear(in_features=256, out_features=10, bias=True)
)
在看下參數情況
# Show the shape of every learnable parameter (weights and biases per layer).
for param_name, param in net.named_parameters():
    print(param_name, param.size())
hidden1.weight torch.Size([128, 784])
hidden1.bias torch.Size([128])
hidden2.weight torch.Size([256, 128])
hidden2.bias torch.Size([256])
out.weight torch.Size([10, 256])
out.bias torch.Size([10])
使用TensorDataset和DataLoader來構建數據
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader

bs = 64  # mini-batch size

# Wrap the tensors in datasets and iterable mini-batch loaders.
# Only the training loader shuffles; validation order doesn't matter.
train_ds = TensorDataset(x_train, y_train)
train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True)
valid_ds = TensorDataset(x_valid, y_valid)
valid_dl = DataLoader(valid_ds, batch_size=bs)
def get_data(train_ds, valid_ds, bs):
    """Build the (train, valid) DataLoader pair for the given datasets.

    The training loader shuffles each epoch; the validation loader keeps
    its natural order. Both use mini-batches of size `bs`.
    """
    train_loader = DataLoader(train_ds, batch_size=bs, shuffle=True)
    valid_loader = DataLoader(valid_ds, batch_size=bs)
    return train_loader, valid_loader
訓練過程
# Multiclass cross-entropy; it expects raw logits, which is what the
# model's forward returns (no softmax there).
loss_func = F.cross_entropy
# Evaluate one mini-batch; when an optimizer is supplied, also update weights.
def loss_batch(model, loss_func, xb, yb, opt=None):
    """Return (loss value, batch size) for a single batch.

    With `opt` given this is a training step: backpropagate, apply the
    optimizer step, then zero the gradients for the next batch. Without
    `opt` it only measures the loss.
    """
    batch_loss = loss_func(model(xb), yb)
    if opt is not None:
        batch_loss.backward()
        opt.step()
        opt.zero_grad()
    return batch_loss.item(), len(xb)
# Build the model together with its optimizer.
def get_model(lr=0.001):
    """Return a fresh Mnist_NN and an SGD optimizer over its parameters.

    Args:
        lr: SGD learning rate. Defaults to 0.001 — the value that was
            previously hard-coded — so existing callers are unaffected,
            while new callers can tune it.
    """
    model = Mnist_NN()
    return model, optim.SGD(model.parameters(), lr=lr)
# Training loop.
def fit(steps, model, loss_func, opt, train_dl, valid_dl):
    """Train for `steps` epochs, printing mean validation loss after each.

    Args:
        steps: number of full passes over the training loader.
        model: network to optimize.
        loss_func: callable (predictions, targets) -> scalar loss tensor.
        opt: optimizer updating the model's parameters.
        train_dl: shuffled DataLoader over the training set.
        valid_dl: DataLoader over the validation set.
    """
    for step in range(steps):
        model.train()  # training mode (matters for dropout/batchnorm if added)
        for xb, yb in train_dl:
            loss_batch(model, loss_func, xb, yb, opt)
        model.eval()  # evaluation mode for the validation pass
        with torch.no_grad():  # no gradients needed when only measuring loss
            losses, nums = zip(
                *[loss_batch(model, loss_func, xb, yb) for xb, yb in valid_dl]
            )
        # Size-weighted mean: the last batch may be smaller than bs, so a plain
        # average of per-batch losses would be biased. Pure Python replaces the
        # original numpy reduction — `np` was never imported in this file.
        val_loss = sum(l * n for l, n in zip(losses, nums)) / sum(nums)
        print('當前step:' + str(step), '驗證集損失:' + str(val_loss))
# Kick off training: rebuild the loaders, create model + optimizer, fit 25 epochs.
train_dl, valid_dl = get_data(train_ds, valid_ds, bs)
model, opt = get_model()
fit(25, model, loss_func, opt, train_dl, valid_dl)
當前step:0 驗證集損失:2.2796445930480957
當前step:1 驗證集損失:2.2440698066711424
當前step:2 驗證集損失:2.1889826164245605
當前step:3 驗證集損失:2.0985311767578123
當前step:4 驗證集損失:1.9517273582458496
當前step:5 驗證集損失:1.7341805934906005
當前step:6 驗證集損失:1.4719875366210937
當前step:7 驗證集損失:1.2273896869659424
當前step:8 驗證集損失:1.0362271406173706
當前step:9 驗證集損失:0.8963696184158325
當前step:10 驗證集損失:0.7927186088562012
當前step:11 驗證集損失:0.7141492074012756
當前step:12 驗證集損失:0.6529350900650024
當前step:13 驗證集損失:0.60417300491333
當前step:14 驗證集損失:0.5643046331882476
當前step:15 驗證集損失:0.5317994566917419
當前step:16 驗證集損失:0.5047958114624024
當前step:17 驗證集損失:0.4813900615692139
當前step:18 驗證集損失:0.4618900228500366
當前step:19 驗證集損失:0.4443243554592133
當前step:20 驗證集損失:0.4297310716629028
當前step:21 驗證集損失:0.416976597738266
當前step:22 驗證集損失:0.406348459148407
當前step:23 驗證集損失:0.3963301926612854
當前step:24 驗證集損失:0.38733808159828187