蒙特卡洛樹 2017 EC-Final L.SOS

蒙特卡洛樹 2017 EC-Final L.SOS

最近看AlphaZero論文學習了蒙特卡洛樹(MCT),隱約記得很久以前EC-Final上有人說可以利用MCT打表,決定練練手。從構建到完成耗時兩天。

蒙特卡洛樹的學習可以參考:https://blog.csdn.net/ljyt2/article/details/78332802

題目可以參考: https://vjudge.net/problem/1500546/origin

蒙特卡洛樹主要分爲 tree_traversal、node_expand、rollout、backpropagation 四個步驟。在本例中分別對應 choose、expand、rollout、update 四個函數或代碼段,rollout 調用 wander 函數來隨機選取局面。

Choose Node

引入UCB函數來評判

當UCB函數只根據平均值$ \overline{V} $來選擇node的時候很容易陷入局部最優導致局面趨於平局,這是由於每一步只有少量的選擇能到達必勝態,隨機漫步難以發現這種必勝態,故需要使得樹發現更多的新節點。

UCB函數中UCB_C值越大,越注重探索, UCB_C值越小越注重原來的選擇。而C值大就意味着需要更多的迭代才能期望結果更接近最優。C值小就更可能陷入局部最優。這與過擬合非常相似。

 

PS:不同於普通的DP算法和枚舉算法,蒙特卡洛樹的優點是可以控制迭代次數或者迭代時間,可以以此來獲得當前的最優解。

代碼不足之處:MCT沒有函數來清空節點內存, ans()對最優節點的選取應該不利用UCB附加值(可以對其重寫),也許可以參照AlphaZero對MCT的變體來加速

 

代碼如下:

/*
author: InFiNiTeemo
abstract: This program solves 2017 EC-Final Problem L, with low efficiency, by letting the computer
build the answer table. Many program segments can be optimized further. It is a tutorial to help beginners understand MCT (Monte Carlo tree).
*/

#include<bits/stdc++.h>
#include<random>
#include<windows.h>
using namespace std;

// Pseudo-random engine used by Node::wander(); seeded from the wall clock at start-up.
mt19937 generator(static_cast<unsigned>(time(nullptr)));
// Uncomment the next line to print (and pause on) every selection / expansion step.
//#define DEBUG
// Iteration counter of the running search: reset and incremented by MTC_TREE::game(),
// read by Node::UCT().
int sum{0};

/**
 * One board position inside the Monte Carlo search tree.
 *
 * The board is a row of `size` cells; every move writes 'S' or 'O' into an
 * empty cell. `turn` counts the moves played so far, so an odd `turn` means
 * the first player made the most recent move. Scores are expressed from the
 * first player's point of view: +1 first player wins, -1 second player wins.
 *
 * A node owns the children attached through add_edge()/expand_node() and
 * deletes them in its destructor. Nodes returned by wander() are NOT attached
 * to the tree; the caller owns them.
 */
class Node {
private:
	// Magnitude used as the score of a node that has never been visited.
	static constexpr int INF = 100000000;
	// Exploration weight: larger values favour less-visited children.
	static constexpr int UCT_C = 4;

	// Maps a 0/1 index to the piece it stands for.
	static int piece(int kind) {
		return kind == 0 ? 'S' : 'O';
	}

	//evaluation
	double q = 0;      // sum of all results passed to update()
	double avg_q = 0;  // q / c once visited; +-INF sentinel before that
	int c = 0;         // number of update() calls
	//transformation
	std::vector<Node*> son;  // owned children
	//state
	std::vector<int> G;      // cells: 0 = empty, otherwise 'S' or 'O'
	int size = 0;            // number of cells
	Node* father = nullptr;  // not used by any code in this class
	int turn = 0;            // moves played; odd => score is maximised for this node
	bool leaf = true;        // true until expand_node() attaches children
	int end = -2; //-2 unjudged, -1 lose, 0 don't know/ draw, 1 win
public:
	/** Creates the empty starting position with `sz` cells. */
	Node(int sz) : avg_q(-INF), G(std::vector<int>(sz, 0)), size(sz) {}

	/**
	 * Creates the position reached from board `previous` (at move count
	 * `f_turn`) by writing `kind` into cell `place`.
	 * `place` must be a valid index of `previous`; it is not range-checked.
	 */
	Node(std::vector<int> previous, int f_turn, int place, int kind)
		: G(std::move(previous)), size(static_cast<int>(G.size())), turn(f_turn + 1) {
		G[place] = kind;
		// Unvisited nodes must look maximally attractive to whoever picks among them.
		avg_q = turn & 1 ? INF : -INF;
	}

	/** Deletes every child attached to this node (recursively). */
	~Node() {
		for (Node* child : son) delete child;
	}

	// The node owns raw child pointers, so copying would double-delete them.
	Node(const Node&) = delete;
	Node& operator=(const Node&) = delete;

	/** Records one playout result and refreshes the running average. */
	void update(double value) {
		q += value;
		c++;
		avg_q = q / c;
	}

	/**
	 * Score used to rank this node against its siblings: the average result
	 * plus (odd turn) or minus (even turn) an exploration bonus that shrinks
	 * with the visit count. Unvisited nodes return their +-INF sentinel.
	 */
	double UCT() const {
		if (c == 0) return avg_q;
		// Clamp so log() never sees 0 when called before the first iteration completes.
		const double total = std::max(sum, 1);
		const double bonus = std::sqrt(UCT_C * std::log(total) / c);
		return (turn & 1) ? avg_q + bonus : avg_q - bonus;
	}

	/** Attaches `son_nd` as a child; this node takes ownership of it. */
	void add_edge(Node* son_nd) {
		son.push_back(son_nd);
	}

	/** True while expand_node() has not attached children. */
	bool is_leaf() const {
		return leaf;
	}

	/** True when every cell has been played. */
	bool is_full() const {
		return size - turn == 0;
	}

	/** Returns 1 if the board contains S-O-S, otherwise 0 (result is cached). */
	int is_end() {
		if (end == -2) end = evaluate();
		return std::abs(end);
	}

	/**
	 * Picks the child with the highest UCT() when this node's turn is even and
	 * the lowest when it is odd; ties keep the earlier child.
	 * Returns nullptr when the node has no children.
	 */
	Node* choose_node() {
		if (is_full()) {
			std::cerr << "Node.Choose Node: no left grid for node choosing." << std::endl;
		}
#ifdef DEBUG
		std::cout << "--------CHOOSE_NODE---------" << std::endl;
		show();
		std::cout << "----------------------------" << std::endl;
		Sleep(300);
#endif // DEBUG
		const bool pick_lowest = (turn & 1) != 0;
		Node* best = nullptr;
		double best_score = 0;
		for (Node* candidate : son) {
#ifdef DEBUG
			std::cout << "---------------------" << std::endl;
			candidate->show();
			std::cout << "-----------------------" << std::endl;
			Sleep(300);
#endif // DEBUG
			const double score = candidate->UCT();
			const bool better = pick_lowest ? score < best_score : score > best_score;
			if (best == nullptr || better) {
				best = candidate;
				best_score = score;
			}
		}
		return best;
	}

	/**
	 * Plays one uniformly random move (random empty cell, random piece) and
	 * returns the resulting position as a new heap node owned by the caller.
	 * Returns nullptr when the board is already full.
	 */
	Node* wander() {
		const int left = size - turn;
		if (left <= 0) {
			std::cerr << "Node.Wander Node: no left grid for wandering." << std::endl;
			return nullptr;
		}
		int nxt = generator() % left + 1;  // 1-based rank of the empty cell to fill
		const int kind = generator() % 2;
		for (int i = 0; i < size; i++) {
			if (G[i] != 0) continue;
			if (--nxt == 0) return new Node(G, turn, i, piece(kind));
		}
		return nullptr;  // only reachable if the board has fewer empty cells than size - turn
	}

	/**
	 * Attaches one child per (empty cell, piece) pair and clears the leaf flag.
	 * Does nothing on a node that is already expanded; on a full board it only
	 * reports the error, so the node stays a leaf.
	 */
	void expand_node() {
		if (is_full()) {
			std::cerr << "Node.Expand_Node: no left grid for node expanding." << std::endl;
			return;
		}
		if (!leaf) return;
		leaf = false;
		son.reserve(2 * static_cast<std::size_t>(size - turn));
		for (int i = 0; i < size; i++) {
			if (G[i] != 0) continue;
			for (int j = 0; j < 2; j++) {
				Node* child = new Node(G, turn, i, piece(j));
#ifdef DEBUG
				std::cout << "-------EXPAND------------" << std::endl;
				child->show();
				std::cout << "-----------------------------" << std::endl;
				Sleep(300);
#endif
				add_edge(child);
			}
		}
	}

	/** Number of results recorded through update(). */
	int visit_count() const {
		return c;
	}

	/**
	 * Returns 0 when the board has no S-O-S; otherwise +1 if `turn` is odd
	 * and -1 if it is even (the win is credited to whoever moved last).
	 */
	int evaluate() const {
		for (int i = 0; i + 2 < size; i++) {
			if (G[i] == 'S' && G[i + 1] == 'O' && G[i + 2] == 'S') {
				return (turn & 1) ? 1 : -1;
			}
		}
		return 0;
	}

	/** Prints the move count, the board ('-' for empty cells) and UCT(). */
	void show() const {
		std::cout << "Turn: " << turn << std::endl;
		for (int cell : G) {
			if (cell == 0) cell = '-';
			std::cout << std::setw(4) << char(cell);
		}
		std::cout << std::endl;
		std::cout << "value:  " << UCT() << std::endl;
	}
};





class MTC_TREE {
private:
	Node* root;
	//size of chessboard size
	int size;
	const int iter = 1e7;

	//debug
	queue<Node*> Rollout_que;

public:
	MTC_TREE(int sz=0):size(sz) {
		root = new Node(sz);
		root->expand_node();
	}

	void game() {
		//loop
		sum = 0;
		for (int i = 0; i < iter; i++) {
			single_game();
			sum++;
		}
		//ans
		ans();
	}
	void single_game(){
		//init

		//choose
		Node* p = root;
		queue<Node*> que;
		while (!p->is_leaf()) {
			p = p->choose_node();
			Node* q = p;
			que.push(q);
		}
		//rollout
		Node* q = roll_out(p);
		//q->show();
		int x = q->evaluate();
		while(!Rollout_que.empty()) {
			Node* t = Rollout_que.front(); Rollout_que.pop();
			delete t;
		}
		//update
		while (!que.empty()) {
			Node* t = que.front(); que.pop();
			if (t->visit_count() == 0 && t->is_leaf() && !t->is_full()) {
				t->expand_node();
			}
			t->update(x);
		}
	}
	Node* roll_out(Node* p) {
		while (!(p->is_full() || p->is_end())) {
			p = p->wander();
			Rollout_que.push(p);
		}
		return p;
	}
	void ans() {
		Node* p = root;
		while (!p->is_leaf()) {
			p = p->choose_node();
		}
		//cout << (p->is_leaf()?"true":"false") << endl;
		p->show();
		string result[] = { "B win", "Draw", "A win" };
		cout << result[p->evaluate()+1] << endl;	
	}
};

/*
 * Builds and runs a fresh search tree for each board length from 3 through 100,
 * printing the length before each result, then waits via the shell's "pause".
 */
int main() {
	int board_len = 3;
	while (board_len <= 100) {
		MTC_TREE tree(board_len);
		cout << board_len << ":  ";
		tree.game();
		++board_len;
	}
	system("pause");
	return 0;
}

 

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章