best_mean_reward, n_steps = -np.inf, 0


def callback(_locals, _globals):
    """
    Callback called at each step (for DQN and others) or after n steps (see ACER or PPO2).

    Every 10 calls, reads the Monitor logs in ``log_dir`` and prints the mean
    reward over the last 100 episodes. Once that mean exceeds 195 (presumably
    the env's "solved" threshold — confirm against the environment used), the
    current model is saved and training is stopped by returning False.

    :param _locals: (dict) local variables of the training loop
        (the model is available as _locals['self'])
    :param _globals: (dict) global variables of the training loop
    :return: (bool) False to stop training, True to continue
    """
    global n_steps, best_mean_reward
    # Print stats every 10 calls (the old comment said 1000, but the code checks % 10)
    if (n_steps + 1) % 10 == 0:
        # Evaluate policy training performance from the Monitor log files
        x, y = ts2xy(load_results(log_dir), 'timesteps')
        if len(x) > 0:
            mean_reward = np.mean(y[-100:])
            # Track the best mean reward seen so far
            # (was declared global but never updated in the original)
            if mean_reward > best_mean_reward:
                best_mean_reward = mean_reward
            print("timestep:{}, mean reward per 100 episode: {:.2f}".format(x[-1], mean_reward))
            # Reward threshold reached: save the model and stop training
            if mean_reward > 195:
                _locals['self'].save(log_dir + 'best_model.pkl')
                return False
    n_steps += 1
    # Explicit True: Stable-Baselines stops training only on a False return,
    # so make the "continue" path unambiguous instead of returning None.
    return True
def train(env, trian_timesteps, algo):
    """
    Train a model with the chosen algorithm, save it, and plot the results.

    NOTE(review): relies on module-level globals ``env_name``, ``log_dir``,
    ``callback`` and the Stable-Baselines imports — assumed defined elsewhere
    in the file.

    :param env: (gym.Env) the training environment (wrapped in a DummyVecEnv
        for PPO2/A2C)
    :param trian_timesteps: (int or str) total training timesteps; the
        parameter name is a typo of "train_timesteps" but is kept for
        backward compatibility with keyword callers
    :param algo: (str) 'dqn' or 'ppo'; any other value falls back to A2C
    """
    if algo == 'dqn':
        # DQN takes the raw environment directly (no vectorized wrapper)
        model = DQN('MlpPolicy', env, verbose=0)
    elif algo == 'ppo':
        # Bind env as a default argument: `env` is rebound on this same line,
        # so a plain `lambda: env` only works because DummyVecEnv happens to
        # call the factory inside __init__, before the assignment completes.
        env = DummyVecEnv([lambda env=env: env])
        model = PPO2(MlpPolicy, env, verbose=0)
    else:
        env = DummyVecEnv([lambda env=env: env])
        model = A2C(MlpPolicy, env, verbose=0)
    model.learn(total_timesteps=int(trian_timesteps), callback=callback)
    model.save("./trained_models/{}-{}-{}".format(algo, env_name, trian_timesteps))
    results_plotter.plot_results([log_dir], trian_timesteps, results_plotter.X_TIMESTEPS,
                                 "{}, {}".format(algo, env_name))
    plt.savefig("./trained_models/{}-{}-{}.png".format(algo, env_name, trian_timesteps))
    plt.show()
    # Print the environment's name, not the repr of the (possibly wrapped)
    # env object — matches the naming used in the saved-file messages above.
    print('{} {} training finished.'.format(algo, env_name))
    # Drop the reference so TensorFlow session/graph resources can be freed
    del model
# --- Scraped-page residue (kept as comments so the file parses) ---
# Original title: "TensorFlow training-process control callback"
# "Post a comment" / "All comments"
# "No comments yet — want to be the first? Type in the comment box above and click publish."