Appleman and Tree CodeForces - 461B(树形dp)

Appleman has a tree with n vertices. Some of the vertices (at least one) are colored black and other vertices are colored white.

Consider a set consisting of k (0 ≤ k < n) edges of Appleman’s tree. If Appleman deletes these edges from the tree, then it will split into (k + 1) parts. Note, that each part will be a tree with colored vertices.

Now Appleman wonders, what is the number of sets splitting the tree in such a way that each resulting part will have exactly one black vertex? Find this number modulo 1000000007 (109 + 7).

Input
The first line contains an integer n (2  ≤ n ≤ 105) — the number of tree vertices.

The second line contains the description of the tree: n - 1 integers p 0, p 1, …, p n - 2 (0 ≤ p i ≤ i). Where p i means that there is an edge connecting vertex (i + 1) of the tree and vertex p i. Consider tree vertices are numbered from 0 to n - 1.

The third line contains the description of the colors of the vertices: n integers x 0, x 1, …, x n - 1 ( x i is either 0 or 1). If x i is equal to 1, vertex i is colored black. Otherwise, vertex i is colored white.

Output
Output a single integer — the number of ways to split the tree modulo 1000000007 (109 + 7).

Examples
Input
3
0 0
0 1 1
Output
2
Input
6
0 1 1 0 4
1 1 0 0 1 0
Output
1
Input
10
0 1 2 1 4 4 4 0 8
0 0 0 1 0 1 1 0 0 1
Output
27

题意:
一棵树,每个点可以是黑色也可以是白色。
要求减掉k1k-1个边分成kk块使得每个块只有一个黑点

思路:
定义dp[u][0/1]dp[u][0/1]代表uu为根节点子树有没有黑点,有一个黑点的分法。

则对于dp[u][1]dp[u][1]
如果当前子树有黑点,要去掉这个边,也就是dp[u][1]dp[v][1]dp[u][1]*dp[v][1]
如果当前子树有黑点,可以保留这个边,也就是dp[u][0]dp[v][1]dp[u][0]*dp[v][1]
如果当前子树没有黑点,那么肯定要保留这个边,否则子树连通块没有黑点了,也就是dp[u][1]dp[v][0]dp[u][1]*dp[v][0]

对于dp[u][0]dp[u][0]
如果当前子树没有黑点,则保留这个边,也就是dp[u][0]dp[v][0]dp[u][0]*dp[v][0]
如果当前子树有黑点,则减掉这个边,也就是dp[u][0]dp[v][1]dp[u][0]*dp[v][1]

然后转移就好了。

对于这种树dp计数的问题总不是很感冒,感觉一般就是,乘法原理加上之前子树划分的状态和当前子树的状态,最后可能还要去重。

#include <cstdio>
#include <iostream>
#include <cstring>
#include <algorithm>
#include <set>
#include <queue>
#include <map>
#include <string>
#include <iostream>
#include <cmath>

using namespace std;
typedef long long ll;

typedef long long ll;
const int maxn = 1e5 + 7;
const int mod = 1e9 + 7;
vector<int>G[maxn];
int a[maxn];
ll dp[maxn][2];

void dfs(int u,int fa) {
    if(a[u] == 1) {
        dp[u][1] = 1;
    } else {
        dp[u][0] = 1;
    }
    for(int i = 0;i < G[u].size();i++) {
        int v = G[u][i];
        if(v == fa) continue;
        dfs(v,u);
        dp[u][1] = (dp[u][1] * dp[v][0] % mod + dp[u][0] * dp[v][1] % mod + dp[u][1] * dp[v][1] % mod) % mod;
        dp[u][0] = (dp[u][0] * dp[v][0] % mod + dp[u][0] * dp[v][1] % mod) % mod;
    }
}

int main() {
    int n;scanf("%d",&n);
    for(int i = 2;i <= n;i++) {
        int x;scanf("%d",&x);x++;
        G[x].push_back(i);
        G[i].push_back(x);
    }
    for(int i = 1;i <= n;i++) scanf("%d",&a[i]);
    dfs(1,-1);
    printf("%lld\n",dp[1][1]);
    return 0;
}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章