HDU 5593 ZYB's Tree(樹形DP 好題(java))



大致題意:

有n = 500000節點的樹, 對於每個節點求距離此節點不超過K (K <= 10)的節點有多少個,把這個n個答案XOR後輸出

思路:

題意中的邊是通過 “For reading: we have two numbers A and B, let $fa_i$ be the father of node $i$, $fa_1 = 0$, $fa_i = (A \cdot i + B) \bmod (i - 1) + 1$ for $i \in [2, N]$.” 

這樣構造出來的,這只是便於快速讀入而已,沒有給解題帶來作用

注意到K最多隻有10個, 然後用dp[n][K],代表對於n節點距離不超過K的有多少個節點。用樹形DP很好求出在n點下方的不超過K的節點數,對於n點上方

的就不好處理了。

所以用兩輪樹形DP,一個求dpd[n][k]:位於n點子樹內(含n自身)、距離恰好為k的節點數(從下往上dp),另一個求dpu[n][k]:位於n點子樹之外、距離恰好為k的節點數(從上往下dp)。最後對k = 0..K把兩者求和,即得距離n不超過K的節點總數。

看代碼時不難理解遞推公式的


//1560MS	107360K	5712B	Java	2015-12-05 23:37:51
import java.util.HashMap;
import java.util.Map;
import java.math.*;
import java.io.BufferedReader;
import java.io.OutputStream;
import java.io.IOException;  
import java.io.InputStream;  
import java.io.InputStreamReader;  
import java.math.BigInteger;  
import java.util.StringTokenizer;  
import java.io.PrintWriter;
import java.util.Arrays;

/**
 * Solution driver for "ZYB's Tree": for every node of a rooted tree, count the nodes whose
 * distance from it is at most {@code K}, and print the XOR of all those counts.
 *
 * <p>The tree is stored as a forward-star adjacency list ({@code head}/{@code nxt}/{@code es})
 * that only holds parent-to-child edges. Two passes are used: {@link #DFS1} counts, per node,
 * the nodes at each exact distance inside its subtree; {@link #DFS2} counts the nodes at each
 * exact distance outside its subtree. Both passes are iterative, so the depth of the tree is
 * not limited by the thread stack size.
 */
public class Main {

    public static void main(String[] args) {
        InputStream inputStream = System.in;
        OutputStream outputStream = System.out;
        Scanner in = new Scanner(inputStream);
        PrintWriter out = new PrintWriter(outputStream);
        TaskE solver = new TaskE();
        solver.solve(1, in, out);
        out.close();
    }

    /** XOR accumulator of the per-node answers (written by {@link #DFS2} and by solve). */
    static int ans;
    /** Number of edges stored so far; edge ids are 1-based, 0 terminates a list. */
    static int ecnt;
    /** Per-test-case parameters as read from the input. */
    static int n, A, K, B;

    static class TaskE {
        /**
         * Reads the number of test cases and, for each one, {@code n K A B}; builds the tree
         * from {@code fa[i] = (A*i + B) % (i-1) + 1} and prints the XOR of all node answers.
         *
         * @param testNumber unused
         * @param in         token source
         * @param out        destination; one line is printed (and flushed) per test case
         */
        public void solve(int testNumber, Scanner in, PrintWriter out) {
            int cases = in.nextInt();
            for (int c = 0; c < cases; c++) {
                ecnt = 0;
                ans = 0;
                n = in.nextInt();
                K = in.nextInt();
                A = in.nextInt();
                B = in.nextInt();

                // Fresh arrays are zero-filled, which is exactly the "empty list" state.
                int[] head = new int[n + 10];
                int[] es = new int[n + 10];
                int[] nxt = new int[n + 10];
                int[] fa = new int[n + 10];
                // K + 2 columns so that index 1 is always writable, even when K == 0.
                int[][] dpd = new int[n + 10][K + 2];
                int[][] dpu = new int[n + 10][K + 2];

                for (int i = 2; i <= n; i++) {
                    // Widen to long: A * i may not fit in an int.
                    fa[i] = (int) (((long) A * i + B) % (i - 1) + 1);
                    add(fa[i], i, es, head, nxt);
                }

                DFS1(1, dpd, nxt, es, head);

                // The root has nothing above it, so the downward pass starts at its children.
                for (int e = head[1]; e != 0; e = nxt[e]) {
                    DFS2(es[e], fa, dpd, dpu, nxt, es, head);
                }

                int rootTotal = 0;
                for (int j = 0; j <= K; j++) {
                    rootTotal += dpu[1][j] + dpd[1][j];
                }
                ans ^= rootTotal;

                out.println(ans);
                out.flush();
            }
        }
    }

    /**
     * Appends the directed edge {@code u -> v} to the adjacency list of {@code u}.
     * Increments {@link #ecnt}; the new edge becomes the first edge of {@code u}.
     */
    static void add(int u, int v, int es[], int head[], int nxt[]) {
        ecnt++;
        es[ecnt] = v;
        nxt[ecnt] = head[u];
        head[u] = ecnt;
    }

    /**
     * Bottom-up pass: for every node {@code x} in the subtree of {@code u}, sets
     * {@code dpd[x][0] = 1} and adds to {@code dpd[x][j]} (1 &lt;= j &lt;= K) the number of
     * nodes in the subtree of {@code x} at distance exactly {@code j}.
     *
     * <p>Implemented without recursion: nodes are collected in breadth-first order, then
     * folded into their parents in reverse order. Every descendant of a node appears after
     * it in that order, so each row is complete before it is added to its parent's row.
     *
     * @param u    root of the subtree to process
     * @param dpd  table to fill; rows are expected to start zeroed
     * @param nxt  next-edge links of the adjacency list
     * @param es   edge targets (children)
     * @param head first edge id per node, 0 when the node has no children
     */
    static void DFS1(int u, int dpd[][], int nxt[], int es[], int head[]) {
        int[] order = new int[16];
        int[] parentOf = new int[16];
        int size = 0;
        order[size++] = u;

        for (int at = 0; at < size; at++) {
            int node = order[at];
            dpd[node][0] = 1;
            for (int e = head[node]; e != 0; e = nxt[e]) {
                if (size == order.length) {
                    order = Arrays.copyOf(order, size * 2);
                    parentOf = Arrays.copyOf(parentOf, size * 2);
                }
                order[size] = es[e];
                parentOf[size] = node;
                size++;
            }
        }

        // Index 0 is u itself; it has no parent inside this traversal.
        for (int at = size - 1; at >= 1; at--) {
            int[] childRow = dpd[order[at]];
            int[] parentRow = dpd[parentOf[at]];
            for (int j = 1; j <= K; j++) {
                parentRow[j] += childRow[j - 1];
            }
        }
    }

    /**
     * Top-down pass over the subtree of {@code u} (which must not be the root): for every
     * node {@code x} visited, fills {@code dpu[x][j]} with the number of nodes outside the
     * subtree of {@code x} at distance exactly {@code j}, then XORs the total number of nodes
     * within distance {@code K} of {@code x} into {@link #ans}.
     *
     * <p>Implemented with an explicit stack instead of recursion. A node is only pushed after
     * its parent has been processed, so {@code dpu[fa[x]]} is final when {@code x} is popped.
     * The visiting order differs from a recursive walk, which is harmless because XOR is
     * commutative.
     *
     * @param u    subtree root to start from; {@code dpu[fa[u]]} must already be final
     * @param fa   parent of each node
     * @param dpd  subtree counts produced by {@link #DFS1}
     * @param dpu  table to fill
     * @param nxt  next-edge links of the adjacency list
     * @param es   edge targets (children)
     * @param head first edge id per node, 0 when the node has no children
     */
    static void DFS2(int u, int fa[], int dpd[][], int dpu[][], int nxt[], int es[], int head[]) {
        int[] stack = new int[16];
        int top = 0;
        stack[top++] = u;

        while (top > 0) {
            int node = stack[--top];
            int pa = fa[node];
            int[] up = dpu[node];
            int[] down = dpd[node];
            int[] parentUp = dpu[pa];
            int[] parentDown = dpd[pa];

            // Distance 1 upwards is the parent itself.
            up[1] = parentDown[0];
            // Distance i + 1 from node == distance i from the parent, excluding the nodes
            // that live in node's own subtree (those are at distance i - 1 below node).
            for (int i = 1; i < K; i++) {
                up[i + 1] = (parentDown[i] - down[i - 1]) + parentUp[i];
            }

            int total = 0;
            for (int j = 0; j <= K; j++) {
                total += up[j] + down[j];
            }
            ans ^= total;

            for (int e = head[node]; e != 0; e = nxt[e]) {
                if (top == stack.length) {
                    stack = Arrays.copyOf(stack, top * 2);
                }
                stack[top++] = es[e];
            }
        }
    }


    /** Pair of ints ordered by {@code X}, then by {@code Y}. */
    static class pii implements Comparable<pii> {
        int X, Y;

        pii() {
            this.X = 0;
            this.Y = 0;
        }

        pii(int X, int Y) {
            this.X = X;
            this.Y = Y;
        }

        /**
         * Compares by {@code X} first and {@code Y} second. Uses {@link Integer#compare}
         * rather than subtraction, which overflows for operands of opposite sign.
         */
        @Override
        public int compareTo(pii a) {
            int byX = Integer.compare(this.X, a.X);
            if (byX != 0) {
                return byX;
            }
            return Integer.compare(this.Y, a.Y);
        }
    }

    /**
     * Whitespace-token reader on top of a {@link BufferedReader}.
     *
     * <p>{@link #nextLine()} returns {@code null} both at end of input and when reading
     * fails. {@link #next()} therefore throws {@link java.util.NoSuchElementException} when
     * no further token is available.
     */
    static class Scanner {
        BufferedReader br;
        StringTokenizer st;

        public Scanner(InputStream in) {
            br = new BufferedReader(new InputStreamReader(in));
            eat("");
        }

        /** Replaces the current tokenizer with one over {@code s}. */
        private void eat(String s) {
            st = new StringTokenizer(s);
        }

        /** Reads one raw line; {@code null} at end of input or on an I/O error. */
        public String nextLine() {
            try {
                return br.readLine();
            } catch (IOException e) {
                return null;
            }
        }

        /** Pulls lines until a token is available; {@code false} once input is exhausted. */
        public boolean hasNext() {
            while (!st.hasMoreTokens()) {
                String s = nextLine();
                if (s == null) {
                    return false;
                }
                eat(s);
            }
            return true;
        }

        /** Returns the next whitespace-delimited token. */
        public String next() {
            hasNext();
            return st.nextToken();
        }

        public int nextInt() {
            return Integer.parseInt(next());
        }

        public long nextLong() {
            return Long.parseLong(next());
        }

        public double nextDouble() {
            return Double.parseDouble(next());
        }

        public BigInteger nextBigInteger() {
            return new BigInteger(next());
        }

        public int[] nextIntArray(int n) {
            int[] is = new int[n];
            for (int i = 0; i < n; i++) {
                is[i] = nextInt();
            }
            return is;
        }

        public long[] nextLongArray(int n) {
            long[] ls = new long[n];
            for (int i = 0; i < n; i++) {
                ls[i] = nextLong();
            }
            return ls;
        }

        public double[] nextDoubleArray(int n) {
            double[] ds = new double[n];
            for (int i = 0; i < n; i++) {
                ds[i] = nextDouble();
            }
            return ds;
        }

        public BigInteger[] nextBigIntegerArray(int n) {
            BigInteger[] bs = new BigInteger[n];
            for (int i = 0; i < n; i++) {
                bs[i] = nextBigInteger();
            }
            return bs;
        }

        public int[][] nextIntMatrix(int row, int col) {
            int[][] mat = new int[row][];
            for (int i = 0; i < row; i++) {
                mat[i] = nextIntArray(col);
            }
            return mat;
        }

        public long[][] nextLongMatrix(int row, int col) {
            long[][] mat = new long[row][];
            for (int i = 0; i < row; i++) {
                mat[i] = nextLongArray(col);
            }
            return mat;
        }

        public double[][] nextDoubleMatrix(int row, int col) {
            double[][] mat = new double[row][];
            for (int i = 0; i < row; i++) {
                mat[i] = nextDoubleArray(col);
            }
            return mat;
        }

        public BigInteger[][] nextBigIntegerMatrix(int row, int col) {
            BigInteger[][] mat = new BigInteger[row][];
            for (int i = 0; i < row; i++) {
                mat[i] = nextBigIntegerArray(col);
            }
            return mat;
        }
    }
}


發佈了308 篇原創文章 · 獲贊 11 · 訪問量 24萬+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章