Distance in Tree CodeForces - 161D (tree DP)

A tree is a connected graph that doesn’t contain any cycles.

The distance between two vertices of a tree is the length (in edges) of the shortest path between these vertices.

You are given a tree with n vertices and a positive number k. Find the number of distinct pairs of the vertices which have a distance of exactly k between them. Note that pairs (v, u) and (u, v) are considered to be the same pair.

Input
The first line contains two integers n and k (1 ≤ n ≤ 50000, 1 ≤ k ≤ 500) — the number of vertices and the required distance between the vertices.

Next n - 1 lines describe the edges as "a_i b_i" (without the quotes) (1 ≤ a_i, b_i ≤ n, a_i ≠ b_i), where a_i and b_i are the vertices connected by the i-th edge. All given edges are different.

Output
Print a single integer — the number of distinct pairs of the tree’s vertices which have a distance of exactly k between them.

Please do not use the %lld specifier to read or write 64-bit integers in C++. It is preferred to use the cin, cout streams or the %I64d specifier.

Examples
Input
5 2
1 2
2 3
3 4
2 5
Output
4
Input
5 3
1 2
2 3
3 4
4 5
Output
2
Note
In the first sample the pairs of vertexes at distance 2 from each other are (1, 3), (1, 5), (3, 5) and (2, 4).
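A brute-force reference can be useful for checking the samples, or for checking a faster solution on small inputs. It runs a BFS from every vertex and counts the vertices at distance exactly k. Each unordered pair is found once from each end, so the total is halved at the end.

This is a minimal sketch that takes O(n²) time, so it is only meant for small n. The function name `countPairsBrute` and its edge-list parameter are illustrative assumptions, not part of the original solution.

```cpp
#include <cassert>
#include <queue>
#include <utility>
#include <vector>

// Hypothetical brute-force reference: O(n^2) time, intended for small n only.
// Vertices are 1-indexed; edges holds the n - 1 tree edges.
long long countPairsBrute(int n, int k,
                          const std::vector<std::pair<int, int>>& edges) {
    assert(k >= 1);  // the problem guarantees a positive k
    std::vector<std::vector<int>> adj(n + 1);
    for (const auto& e : edges) {
        adj[e.first].push_back(e.second);
        adj[e.second].push_back(e.first);
    }
    long long ordered = 0;  // ordered pairs (s, v) with dist(s, v) == k
    std::vector<int> dist;
    for (int s = 1; s <= n; ++s) {
        dist.assign(n + 1, -1);  // -1 marks "not reached yet"
        std::queue<int> q;
        dist[s] = 0;
        q.push(s);
        while (!q.empty()) {
            int u = q.front();
            q.pop();
            if (dist[u] == k) {  // count it and do not explore further
                ++ordered;
                continue;
            }
            for (int v : adj[u]) {
                if (dist[v] == -1) {
                    dist[v] = dist[u] + 1;
                    q.push(v);
                }
            }
        }
    }
    return ordered / 2;  // each unordered pair was counted from both ends
}
```

For example, `countPairsBrute(5, 2, {{1, 2}, {2, 3}, {3, 4}, {2, 5}})` corresponds to the first sample above.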

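Problem summary:
Count the paths of length exactly k in a tree, i.e. the number of unordered vertex pairs at distance k.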

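Approach:
Root the tree at vertex 1. Define $dp[i][j]$ as the number of paths of length $j$ that start at $i$ and go down into the subtree of $i$. Equivalently, it is the number of vertices in that subtree at distance exactly $j$ from $i$, so $dp[i][0] = 1$.

When a child $v$ of $u$ has been processed, a vertex at distance $j-1$ from $v$ is at distance $j$ from $u$. The answer therefore increases by $\sum_{j=1}^{k} dp[v][j-1] \cdot dp[u][k-j]$.

This sum is taken before $dp[v]$ is merged into $dp[u]$. At that moment $dp[u]$ covers only $u$ itself and the children merged earlier. As a result, the two endpoints lie in different branches at $u$ (or one of them is $u$), their path passes through $u$, and its length is exactly $j + (k-j) = k$. Each pair is counted once, at the topmost vertex of its path.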
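Put together, the code below performs the following update for each child $v$ of $u$. This is a restatement of the approach above, assuming children are handled one at a time:

$$
\begin{aligned}
dp[u][0] &= 1,\\
\mathit{ans} &\mathrel{+}= \sum_{j=1}^{k} dp[v][j-1]\cdot dp[u][k-j] &&\text{(counted first)},\\
dp[u][j] &\mathrel{+}= dp[v][j-1], \quad 1 \le j \le k &&\text{(merged afterwards)}.
\end{aligned}
$$

As a check, take the first sample rooted at vertex 1 with $k = 2$:

- The finished table for vertex 2 is $dp[2] = (1, 2, 1)$: vertex 2 itself, then vertices 3 and 5, then vertex 4.
- Before the merge, $dp[1] = (1, 0, 0)$.
- Merging 2 into 1 adds $dp[2][0]\cdot dp[1][1] + dp[2][1]\cdot dp[1][0] = 0 + 2 = 2$. These are the pairs (1, 3) and (1, 5).
- The other two pairs, (3, 5) and (2, 4), are counted earlier at vertex 2.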

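```cpp
#include <iostream>
#include <cstdio>

using namespace std;

typedef long long ll;
const int maxn = 5e4 + 7;
// Adjacency list as a "forward star": each undirected edge is stored twice.
int head[maxn], nex[maxn << 1], to[maxn << 1], tot;
// dp[u][j] = number of vertices in u's subtree at distance exactly j from u (j <= k <= 500).
int dp[maxn][505];
int n, k;
ll ans;  // number of pairs found so far, accumulated in 64 bits

void add(int x, int y) {
    to[++tot] = y;
    nex[tot] = head[x];
    head[x] = tot;
}

void dfs(int u, int fa) {
    dp[u][0] = 1;  // u itself, at distance 0
    for (int i = head[u]; i; i = nex[i]) {
        int v = to[i];
        if (v == fa) continue;
        dfs(v, u);
        // j = distance from u to the endpoint inside v's subtree;
        // the other endpoint is u itself or lies in a child subtree merged earlier.
        for (int j = 1; j <= k; j++) {
            ans += 1ll * dp[v][j - 1] * dp[u][k - j];
        }
        // Merge v's subtree into u: distance j - 1 from v is distance j from u.
        for (int j = 1; j <= k; j++) {
            dp[u][j] += dp[v][j - 1];
        }
    }
}

int main() {
    scanf("%d%d", &n, &k);
    for (int i = 1; i < n; i++) {
        int x, y;
        scanf("%d%d", &x, &y);
        add(x, y);
        add(y, x);
    }
    dfs(1, -1);
    // Print with cout instead of %lld, as the statement recommends for 64-bit values.
    cout << ans << endl;
    return 0;
}
```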
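With n up to 50000, the recursive `dfs` reaches depth n on a path-shaped tree. If stack depth is a concern, the same DP can run without recursion by processing vertices in reverse BFS order. That order finishes every subtree before it is merged into its parent.

This is a hypothetical variant, not the author's code. The name `countPairsDP`, the edge-list interface, and the flattened `dp` table are assumptions made for the sketch. It keeps the same dp definition and the same count-then-merge order.

```cpp
#include <cassert>
#include <cstddef>
#include <queue>
#include <utility>
#include <vector>

// Hypothetical non-recursive variant of the tree DP (not the original code).
// dp[v][j] = number of vertices in v's subtree at distance exactly j from v.
long long countPairsDP(int n, int k,
                       const std::vector<std::pair<int, int>>& edges) {
    assert(n >= 1 && k >= 1);
    std::vector<std::vector<int>> adj(n + 1);
    for (const auto& e : edges) {
        adj[e.first].push_back(e.second);
        adj[e.second].push_back(e.first);
    }

    // BFS from vertex 1: record each vertex's parent and a top-down visiting order.
    std::vector<int> parent(n + 1, 0), order;
    order.reserve(n);
    std::queue<int> q;
    parent[1] = -1;
    q.push(1);
    while (!q.empty()) {
        int u = q.front();
        q.pop();
        order.push_back(u);
        for (int v : adj[u]) {
            if (v != parent[u]) {  // in a tree, the only visited neighbor is the parent
                parent[v] = u;
                q.push(v);
            }
        }
    }

    // Flattened table: row v stores dp[v][0..k]; dp[v][0] = 1 for v itself.
    const std::size_t w = static_cast<std::size_t>(k) + 1;
    std::vector<int> dp((static_cast<std::size_t>(n) + 1) * w, 0);
    for (int v = 1; v <= n; ++v) dp[v * w] = 1;

    long long ans = 0;
    // Reverse BFS order finishes every subtree before merging it into its parent.
    // order[0] is the root, which has no parent, so it is skipped.
    for (int idx = n - 1; idx >= 1; --idx) {
        int v = order[idx];
        const int* dv = &dp[v * w];
        int* du = &dp[parent[v] * w];
        // Count pairs: one endpoint at distance j from u inside v's subtree,
        // the other at distance k - j in the part of u's subtree merged so far.
        for (int j = 1; j <= k; ++j) ans += 1LL * dv[j - 1] * du[k - j];
        // Merge: distance j - 1 from v becomes distance j from u.
        for (int j = 1; j <= k; ++j) du[j] += dv[j - 1];
    }
    return ans;
}
```

To use it, read n, k and the n - 1 edges with cin, then print `countPairsDP(n, k, edges)`. This should give the same number as the program above. Only the traversal order differs, and the order in which children are merged does not change the count.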