在之前的文章中,我們做了如下工作:
- 如何設計一個類flappy-bird小遊戲:【python實戰】使用pygame寫一個flappy-bird類小遊戲 | 設計思路+項目結構+代碼詳解|新手向
- DFS 算法是怎麼回事,我是怎麼應用於該小遊戲的:【深度優先搜索】一個實例+兩張動圖徹底理解DFS|DFS與BFS的區別|用DFS自動控制我們的小遊戲
- BFS 算法是怎麼回事,我是怎麼應用於該小遊戲的:【廣度優先搜索】一個實例+兩張動圖徹底理解BFS|思路+代碼詳解|用DFS自動控制我們的小遊戲
- 強化學習爲什麼有用?其基本原理:無需公式或代碼,用生活實例談談AI自動控制技術“強化學習”算法框架
- 構建一個簡單的卷積神經網絡,使用DRL框架tianshou匹配DQN算法
構造一個簡單的卷積神經網絡,實現 DQN
本文涉及的 .py
文件有:
DQN_train/gym_warpper.py
DQN_train/dqn_train2.py
DQN_train/dqn_render2.py
requirements
tianshou
pytorch > 1.40
gym
繼續訓練與測試
在本項目地址中,你可以使用如下文件對我訓練的模型進行測試,或者繼續訓練。
繼續訓練該模型
python DQN_train/dqn_train2.py
如圖,我已經訓練了 53 次(每次10個epoch),輸入上述命令,你將開始第 54 次訓練,如果不使用任務管理器強制停止,計算機將一直訓練下去,並自動保存最新一代的權重。
查看效果
python DQN_train/dqn_render2.py 0
注意參數 0 ,輸入 0 代表使用最新的權重。
效果如圖:
上圖中,可以看到我們的 AI 已經學會了一些“知識”:比如如何前往下一層;它還需要多加練習,以學會如何避開這些小方塊構成的障礙。
此外,我保留了一些歷史權重。你還可以輸入參數:7, 10, 13, 21, 37, 40, 47,查看訓練次數較少時,神經網絡的表現。
封裝交互環境
強化學習算法有效,很大程度上取決於獎勵機制設計的是否合理。
事件 | 獎勵 |
---|---|
動作後碰撞障礙物、牆壁 | -1 |
動作後無事發生 | 0.1 |
動作後得分 | 1 |
封裝代碼在 gym_wrapper.py 中,使用類 AmazingBrickEnv2
。
強化學習機制與神經網絡的構建
上節中,我們將 2 幀的數據輸入到卷積層中,目的是:
- 讓卷積層提取出“障礙物邊緣”與“玩家位置”;
- 讓 2 幀數據反映出“玩家速度”信息。
爲了節省計算資源,同時加快訓練速度,我們人爲地替機器提取這些信息:
- 不再將巨大的 2 幀“圖像矩陣”輸入到網絡中;
- 取而代之的是,輸入 2 幀的位置信息;
- 即輸入
玩家xy座標
、左障礙物右上頂點xy座標
、右障礙物左上頂點xy座標
、4個障礙方塊的左上頂點的xy座標
(共14個數); - 如此, 2 幀數據共 28 個數字,我們的神經網絡輸入層只有 28 個神經元,比上一個模型(25600)少了不止一個數量級。
我設計的機制爲:
- 每 2 幀進行一次動作決策;
- 狀態的描述變量爲 2 幀的圖像。
線性神經網絡的構建
class Net(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(28, 128)
self.fc2 = nn.Linear(128, 256)
self.fc3 = nn.Linear(256, 128)
self.fc4 = nn.Linear(128, 3)
def forward(self, obs, state=None, info={}):
if not isinstance(obs, torch.Tensor):
obs = torch.tensor(obs, dtype=torch.float)
x = F.relu(self.fc1(obs))
x = F.relu(self.fc2(x))
x = F.relu(self.fc3(x))
x = self.fc4(x)
return x, state
如上,共四層線性網絡。
記錄訓練的微型框架
爲了保存訓練好的權重,且在需要時可以暫停並繼續訓練,我新建了一個.json
文件用於保存訓練數據。
dqn2_path = osp.join(path, 'DQN_train/dqn_weights/')
if __name__ == '__main__':
round = 0
try:
# 此處 policy 採用 DQN
# 具體 DQN 構建方法見下文
policy.load_state_dict(torch.load(dqn2_path + 'dqn2.pth'))
lines = []
with open(dqn2_path + 'dqn2_log.json', "r") as f:
for line in f.readlines():
cur_dict = json.loads(line)
lines.append(cur_dict)
log_dict = lines[-1]
print(log_dict)
round = log_dict['round']
del lines
except FileNotFoundError as identifier:
print('\n\nWe shall train a bright new net.\n')
pass
while True:
round += 1
print('\n\nround:{}\n\n'.format(round))
result = ts.trainer.offpolicy_trainer(
policy, train_collector, test_collector,
max_epoch=max_epoch, step_per_epoch=step_per_epoch,
collect_per_step=collect_per_step,
episode_per_test=30, batch_size=64,
train_fn=lambda e: policy.set_eps(0.1 * (max_epoch - e) / round),
test_fn=lambda e: policy.set_eps(0.05 * (max_epoch - e) / round), writer=None)
print(f'Finished training! Use {result["duration"]}')
torch.save(policy.state_dict(), dqn2_path + 'dqn2.pth')
policy.load_state_dict(torch.load(dqn2_path + 'dqn2.pth'))
log_dict = {}
log_dict['round'] = round
log_dict['last_train_time'] = datetime.datetime.now().strftime('%y-%m-%d %I:%M:%S %p %a')
log_dict['result'] = json.dumps(result)
with open(dqn2_path + 'dqn2_log.json', "a+") as f:
f.write('\n')
json.dump(log_dict, f)
DQN
import os.path as osp
import sys
dirname = osp.dirname(__file__)
path = osp.join(dirname, '..')
sys.path.append(path)
from amazing_brick.game.wrapped_amazing_brick import GameState
from amazing_brick.game.amazing_brick_utils import CONST
from DQN_train.gym_wrapper import AmazingBrickEnv2
import tianshou as ts
import torch, numpy as np
from torch import nn
import torch.nn.functional as F
import json
import datetime
train_env = AmazingBrickEnv2()
test_env = AmazingBrickEnv2()
state_shape = 28
action_shape = 1
net = Net()
optim = torch.optim.Adam(net.parameters(), lr=1e-3)
'''args for rl'''
estimation_step = 3
max_epoch = 10
step_per_epoch = 300
collect_per_step = 50
policy = ts.policy.DQNPolicy(net, optim,
discount_factor=0.9, estimation_step=estimation_step,
use_target_network=True, target_update_freq=320)
train_collector = ts.data.Collector(policy, train_env, ts.data.ReplayBuffer(size=2000))
test_collector = ts.data.Collector(policy, test_env)
如圖,採用這種方式訓練了 53 個循環(共計 53 * 10 * 300 = 159000 個 step)效果還是一般。
下一節(也是本項目的最後一節),我們將探討線性網絡解決這個控制問題的相對成功的方案。