PyTorch Learning Notes, Hands-On Introduction: Playing FizzBuzz with PyTorch (Part 2)

Environment

from __future__ import print_function
import torch
torch.__version__
'1.4.0'

FizzBuzz

FizzBuzz is a simple game. The rules: count upward from 1; when you reach a multiple of 3, say "fizz"; a multiple of 5, say "buzz"; a multiple of 15, say "fizzbuzz"; otherwise just say the number.

We can write a small program that decides whether to return the plain number, "fizz", "buzz", or "fizzbuzz".

# Encode the desired output as a class index: [number, "fizz", "buzz", "fizzbuzz"]
def fizz_buzz_encode(i):
    if i % 15 == 0: return 3
    elif i % 5 == 0: return 2
    elif i % 3 == 0: return 1
    else: return 0

def fizz_buzz_decode(i, prediction):
    return [str(i), "fizz", "buzz", "fizzbuzz"][prediction]

for i in range(1, 16):
    print(fizz_buzz_decode(i, fizz_buzz_encode(i)))

Output

1
2
fizz
4
buzz
fizz
7
8
fizz
buzz
11
fizz
13
14
fizzbuzz

Defining the model's inputs and outputs (training data)

import numpy as np
import torch

NUM_DIGITS = 10

# Represent each input by an array of its binary digits
# (divide by 2 and take the remainder, then reverse: most significant bit first)
def binary_encode(i, num_digits):
    return np.array([i >> d & 1 for d in range(num_digits)][::-1])

all_data_x = torch.Tensor([binary_encode(i, NUM_DIGITS) for i in range(1, 2 ** NUM_DIGITS)])
all_data_y = torch.LongTensor([fizz_buzz_encode(i) for i in range(1, 2 ** NUM_DIGITS)])
if torch.cuda.is_available():
    all_data_x = all_data_x.cuda()
    all_data_y = all_data_y.cuda()

# Index 0 corresponds to the number 1, so index 100 is the number 101.
# Train on the numbers 101..1023 and hold out 1..100 for testing.
trX = all_data_x[100:]   # 923*10
trY = all_data_y[100:]   # 923
testX = all_data_x[:100] # 100*10
testY = all_data_y[:100] # 100
print(trX[0], trX.shape)
print(testX[0], testX.shape)
print(testY[0], testY.shape)
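As a quick sanity check (a standalone sketch, not part of the original post), `binary_encode` produces the binary representation with the most significant bit first:

```python
import numpy as np

def binary_encode(i, num_digits):
    # i >> d & 1 extracts bit d; [::-1] puts the most significant bit first
    return np.array([i >> d & 1 for d in range(num_digits)][::-1])

print(binary_encode(5, 4))   # 5 is 0101 in 4-bit binary -> [0 1 0 1]
print(binary_encode(10, 4))  # 10 is 1010 in 4-bit binary -> [1 0 1 0]
```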

Defining the model with PyTorch

# Define the model
NUM_HIDDEN = 100
model = torch.nn.Sequential(
    torch.nn.Linear(NUM_DIGITS, NUM_HIDDEN),
    torch.nn.ReLU(),
    torch.nn.Linear(NUM_HIDDEN, 4)
)
if torch.cuda.is_available():
    model = model.cuda()
  • To teach our model to play FizzBuzz, we need to define a loss function and an optimization algorithm.
  • The optimizer repeatedly lowers the loss so that the model achieves as small a loss as possible on this task.
  • A low loss usually means the model performs well; a high loss means it performs poorly.
  • Since FizzBuzz is essentially a classification problem, we use the Cross Entropy Loss.
  • For the optimizer we choose Stochastic Gradient Descent.
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr = 0.05)
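A small aside on how `CrossEntropyLoss` is used here (a toy sketch with made-up logits, not from the original post): it takes raw, unnormalized scores for each of the four classes plus integer class targets, and applies the softmax internally, so the model's last `Linear` layer needs no activation.

```python
import torch

loss_fn = torch.nn.CrossEntropyLoss()

# Two samples, four classes: raw logits (no softmax needed) and integer targets
logits = torch.tensor([[4.0, 0.0, 0.0, 0.0],
                       [0.0, 0.0, 4.0, 0.0]])
targets = torch.tensor([0, 2])

# Both rows put most of their mass on the correct class, so the loss is small
print(loss_fn(logits, targets).item())
```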

Training the model

BATCH_SIZE = 128
for epoch in range(10001):
    for start in range(0, len(trX), BATCH_SIZE):
        end = start + BATCH_SIZE
        batchX = trX[start:end]
        batchY = trY[start:end]
        # forward pass and loss
        y_pred = model(batchX)
        loss = loss_fn(y_pred, batchY)
        # backward pass and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    if epoch % 1000 == 0:
        loss = loss_fn(model(trX), trY).item()
        print("Epoch {} Loss {}".format(epoch, loss))

Training log

Epoch 0 Loss 1.5171986818313599
Epoch 1000 Loss 0.033820319920778275
Epoch 2000 Loss 0.018580008298158646
Epoch 3000 Loss 0.012838841415941715
Epoch 4000 Loss 0.009773609228432178
Epoch 5000 Loss 0.007854906842112541
Epoch 6000 Loss 0.006554984021931887
Epoch 7000 Loss 0.005604262929409742
Epoch 8000 Loss 0.004887698218226433
Epoch 9000 Loss 0.004323565401136875
Epoch 10000 Loss 0.0038704578764736652

Finally, we let the trained model play FizzBuzz on the numbers 1 to 100.

with torch.no_grad():
    resultY = model(testX)
predictions = zip(range(1,101), resultY.max(1)[1].data.tolist())
print([(i,fizz_buzz_decode(i, x)) for (i, x) in predictions])

Predictions

[(1, '1'), (2, '2'), (3, 'fizz'), (4, '4'), (5, 'buzz'), (6, 'fizz'), (7, '7'), (8, '8'), (9, 'fizz'), (10, 'buzz'), (11, '11'), (12, 'fizz'), (13, '13'), (14, '14'), (15, 'fizzbuzz'), (16, '16'), (17, '17'), (18, 'fizz'), (19, '19'), (20, 'buzz'), (21, 'fizz'), (22, '22'), (23, '23'), (24, 'fizz'), (25, 'buzz'), (26, '26'), (27, 'fizz'), (28, '28'), (29, '29'), (30, 'fizzbuzz'), (31, '31'), (32, '32'), (33, 'fizz'), (34, '34'), (35, 'buzz'), (36, 'fizz'), (37, '37'), (38, '38'), (39, 'fizz'), (40, 'buzz'), (41, '41'), (42, 'fizz'), (43, '43'), (44, '44'), (45, 'fizzbuzz'), (46, '46'), (47, '47'), (48, 'fizz'), (49, '49'), (50, 'buzz'), (51, 'fizz'), (52, '52'), (53, '53'), (54, 'fizz'), (55, 'buzz'), (56, '56'), (57, 'fizz'), (58, '58'), (59, '59'), (60, 'fizzbuzz'), (61, '61'), (62, '62'), (63, 'fizz'), (64, '64'), (65, '65'), (66, 'fizz'), (67, '67'), (68, '68'), (69, '69'), (70, 'buzz'), (71, '71'), (72, 'fizz'), (73, '73'), (74, '74'), (75, 'fizzbuzz'), (76, '76'), (77, '77'), (78, 'fizz'), (79, '79'), (80, 'buzz'), (81, 'fizz'), (82, '82'), (83, '83'), (84, '84'), (85, 'buzz'), (86, '86'), (87, 'fizz'), (88, '88'), (89, '89'), (90, 'fizzbuzz'), (91, '91'), (92, '92'), (93, '93'), (94, '94'), (95, 'buzz'), (96, 'fizz'), (97, '97'), (98, '98'), (99, 'fizz'), (100, 'buzz')]

Checking accuracy

print(np.sum(resultY.cpu().max(1)[1].numpy() == testY.cpu().numpy())/len(testY))
print(resultY.cpu().max(1)[1].numpy() == testY.cpu().numpy())

Saving model parameters

torch.save(model.state_dict(), 'params1.pkl')
# To load the parameters back later:
# model.load_state_dict(torch.load('params1.pkl'))
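To restore the model later, you rebuild the same architecture and load the saved state dict into it. The round-trip below is a self-contained sketch (it saves and reloads a freshly initialized model rather than the trained one):

```python
import torch

NUM_DIGITS, NUM_HIDDEN = 10, 100

def make_model():
    return torch.nn.Sequential(
        torch.nn.Linear(NUM_DIGITS, NUM_HIDDEN),
        torch.nn.ReLU(),
        torch.nn.Linear(NUM_HIDDEN, 4),
    )

model = make_model()
torch.save(model.state_dict(), 'params1.pkl')

# A fresh model starts with different random weights;
# load_state_dict restores the saved ones in place
restored = make_model()
restored.load_state_dict(torch.load('params1.pkl'))
restored.eval()

# Both models now produce identical outputs
x = torch.randn(1, NUM_DIGITS)
with torch.no_grad():
    assert torch.allclose(model(x), restored(x))
```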