Controlling the TensorFlow training process with a callback (Stable Baselines)

import numpy as np
import matplotlib.pyplot as plt

from stable_baselines import A2C, DQN, PPO2, results_plotter
from stable_baselines.common.policies import MlpPolicy
from stable_baselines.common.vec_env import DummyVecEnv
from stable_baselines.results_plotter import load_results, ts2xy

best_mean_reward, n_steps = -np.inf, 0
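
The callback and train() below also rely on a Monitor-wrapped environment, a log directory and an environment name that are assumed to exist at module level. A minimal setup sketch, assuming gym's CartPole-v0 (its 195 mean-reward "solved" threshold matches the check in the callback); env_name, log_dir and the ./trained_models/ path are illustrative choices:

import os
import gym
from stable_baselines.bench import Monitor

env_name = 'CartPole-v0'  # illustrative choice; any gym env id works
log_dir = './logs/'       # Monitor writes per-episode stats here for load_results()
os.makedirs(log_dir, exist_ok=True)
os.makedirs('./trained_models/', exist_ok=True)

# Monitor records episode rewards/lengths to log_dir so ts2xy(load_results(log_dir), 'timesteps') can read them
env = Monitor(gym.make(env_name), log_dir, allow_early_resets=True)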


def callback(_locals, _globals):
	"""
	Callback called at each step (for DQN and others) or after n steps (see ACER or PPO2).
	:param _locals: (dict)
	:param _globals: (dict)
	"""
	global n_steps, best_mean_reward
	
	# Print stats every 10 calls
	if (n_steps + 1) % 10 == 0:
		# Evaluate policy training performance
		x, y = ts2xy(load_results(log_dir), 'timesteps')
		if len(x) > 0:
			mean_reward = np.mean(y[-100:])
			print("timestep:{}, mean reward per 100 episode: {:.2f}".format(x[-1], mean_reward))
			
			# Consider the task solved once the mean reward exceeds 195: save the model and stop training
			if mean_reward > 195:
				_locals['self'].save(log_dir + 'best_model.pkl')
				# Returning False tells Stable Baselines to stop training
				return False
				
	n_steps += 1
	# Returning True (or None) lets training continue
	return True


def train(env, train_timesteps, algo):
	
	if algo == 'dqn':
		# DQN works with a plain gym env and resolves 'MlpPolicy' to its own DQN policy
		model = DQN('MlpPolicy', env, verbose=0)
	elif algo == 'ppo':
		# PPO2 and A2C expect a vectorized environment
		env = DummyVecEnv([lambda: env])
		model = PPO2(MlpPolicy, env, verbose=0)
	else:
		env = DummyVecEnv([lambda: env])
		model = A2C(MlpPolicy, env, verbose=0)
		
	model.learn(total_timesteps=int(train_timesteps), callback=callback)
	model.save("./trained_models/{}-{}-{}".format(algo, env_name, train_timesteps))
	results_plotter.plot_results([log_dir], train_timesteps, results_plotter.X_TIMESTEPS, "{}, {}".format(algo, env_name))
	plt.savefig("./trained_models/{}-{}-{}.png".format(algo, env_name, train_timesteps))
	plt.show()
	print('{} {} training finished.'.format(algo, env_name))
	del model
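
A possible invocation, assuming the setup sketch above (env, env_name and log_dir defined at module level); the timestep budget and algorithm choice are example values:

if __name__ == '__main__':
	train(env, int(1e5), 'dqn')  # e.g. 100k timesteps with DQN; 'ppo' selects PPO2, anything else A2C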