Appleman and Tree CodeForces - 461B(樹形dp)

Appleman has a tree with n vertices. Some of the vertices (at least one) are colored black and other vertices are colored white.

Consider a set consisting of k (0 ≤ k < n) edges of Appleman’s tree. If Appleman deletes these edges from the tree, then it will split into (k + 1) parts. Note, that each part will be a tree with colored vertices.

Now Appleman wonders, what is the number of sets splitting the tree in such a way that each resulting part will have exactly one black vertex? Find this number modulo 1000000007 (109 + 7).

Input
The first line contains an integer n (2  ≤ n ≤ 105) — the number of tree vertices.

The second line contains the description of the tree: n - 1 integers p 0, p 1, …, p n - 2 (0 ≤ p i ≤ i). Where p i means that there is an edge connecting vertex (i + 1) of the tree and vertex p i. Consider tree vertices are numbered from 0 to n - 1.

The third line contains the description of the colors of the vertices: n integers x 0, x 1, …, x n - 1 ( x i is either 0 or 1). If x i is equal to 1, vertex i is colored black. Otherwise, vertex i is colored white.

Output
Output a single integer — the number of ways to split the tree modulo 1000000007 (109 + 7).

Examples
Input
3
0 0
0 1 1
Output
2
Input
6
0 1 1 0 4
1 1 0 0 1 0
Output
1
Input
10
0 1 2 1 4 4 4 0 8
0 0 0 1 0 1 1 0 0 1
Output
27

題意:
一棵樹,每個點可以是黑色也可以是白色。
要求減掉k1k-1個邊分成kk塊使得每個塊只有一個黑點

思路:
定義dp[u][0/1]dp[u][0/1]代表uu爲根節點子樹有沒有黑點,有一個黑點的分法。

則對於dp[u][1]dp[u][1]
如果當前子樹有黑點,要去掉這個邊,也就是dp[u][1]dp[v][1]dp[u][1]*dp[v][1]
如果當前子樹有黑點,可以保留這個邊,也就是dp[u][0]dp[v][1]dp[u][0]*dp[v][1]
如果當前子樹沒有黑點,那麼肯定要保留這個邊,否則子樹連通塊沒有黑點了,也就是dp[u][1]dp[v][0]dp[u][1]*dp[v][0]

對於dp[u][0]dp[u][0]
如果當前子樹沒有黑點,則保留這個邊,也就是dp[u][0]dp[v][0]dp[u][0]*dp[v][0]
如果當前子樹有黑點,則減掉這個邊,也就是dp[u][0]dp[v][1]dp[u][0]*dp[v][1]

然後轉移就好了。

對於這種樹dp計數的問題總不是很感冒,感覺一般就是,乘法原理加上之前子樹劃分的狀態和當前子樹的狀態,最後可能還要去重。

#include <cstdio>
#include <iostream>
#include <cstring>
#include <algorithm>
#include <set>
#include <queue>
#include <map>
#include <string>
#include <iostream>
#include <cmath>

using namespace std;
typedef long long ll;

typedef long long ll;
const int maxn = 1e5 + 7;
const int mod = 1e9 + 7;
vector<int>G[maxn];
int a[maxn];
ll dp[maxn][2];

void dfs(int u,int fa) {
    if(a[u] == 1) {
        dp[u][1] = 1;
    } else {
        dp[u][0] = 1;
    }
    for(int i = 0;i < G[u].size();i++) {
        int v = G[u][i];
        if(v == fa) continue;
        dfs(v,u);
        dp[u][1] = (dp[u][1] * dp[v][0] % mod + dp[u][0] * dp[v][1] % mod + dp[u][1] * dp[v][1] % mod) % mod;
        dp[u][0] = (dp[u][0] * dp[v][0] % mod + dp[u][0] * dp[v][1] % mod) % mod;
    }
}

int main() {
    int n;scanf("%d",&n);
    for(int i = 2;i <= n;i++) {
        int x;scanf("%d",&x);x++;
        G[x].push_back(i);
        G[i].push_back(x);
    }
    for(int i = 1;i <= n;i++) scanf("%d",&a[i]);
    dfs(1,-1);
    printf("%lld\n",dp[1][1]);
    return 0;
}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章