POJ 1741 Tree 解題報告(樹分治)

Tree
Time Limit: 1000MS   Memory Limit: 30000K
Total Submissions: 10589   Accepted: 3257

Description

Give a tree with n vertices,each edge has a length(positive integer less than 1001). 
Define dist(u,v)=The min distance between node u and v. 
Give an integer k,for every pair (u,v) of vertices is called valid if and only if dist(u,v) not exceed k. 
Write a program that will count how many pairs which are valid for a given tree. 

Input

The input contains several test cases. The first line of each test case contains two integers n, k. (n<=10000) The following n-1 lines each contains three integers u,v,l, which means there is an edge between node u and v of length l. 
The last test case is followed by two zeros. 

Output

For each test case output the answer on a single line.

Sample Input

5 4
1 2 3
1 3 1
1 4 2
3 5 1
0 0

Sample Output

8

Source


    解題報告:樓教主男人八題之一。樹分治。

    我覺得難點在於,在你計算複雜度之前,你肯定不會想到這麼做。

    每次同此當前每個節點子樹的節點數,找到重心。這個複雜度爲O(n)。統計重心到所有節點的距離,找到所有長度小於等於K的鏈。這裏可以先排序,用O(n)的算法找到所有解。排序複雜度O(n log n)。刪除重心,去重,遞歸下去。每次找的都是重心,可以保證遞歸的深度不超過log n。故總複雜度爲O(n log n log n)。

    問題就解決啦……代碼如下:

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <cmath>
#include <queue>
#include <vector>
#include <map>
#include <set>
#include <string>
#include <iomanip>
#include <cassert>
using namespace std;
#pragma comment(linker, "/STACK:1024000000,1024000000")

#define ff(i, n) for(int i=0;i<(n);i++)
#define fff(i, n, m) for(int i=(n);i<=(m);i++)
#define dff(i, n, m) for(int i=(n);i>=(m);i--)
#define travel(e, u) for(int e = u, v = vv[u]; e; e = nxt[e], v = vv[e])
#define bit(n) (1LL<<(n))
typedef long long LL;
typedef unsigned long long ULL;
void work();
int main()
{
#ifdef ACM
    freopen("in.txt", "r", stdin);
#endif // ACM

    work();
}

void scanf(int & x, char ch = 0)
{
    while((ch = getchar()) < '0' || ch > '9');
    x = ch - '0';
    while((ch = getchar()) >= '0' && ch <= '9') x = 10 * x + (ch - '0');
}

/***************************************************************************************/

const int maxv = 11111;

int n, k;
int ans;

int edge[maxv], ecnt;
int nxt[maxv * 2], vv[maxv * 2], ww[maxv * 2];

bool vis[maxv];
int siz[maxv], mson[maxv];

int mi, root;
int tot, dis[maxv];

void init()
{
    ans = 0;
    ecnt = 2;
    memset(edge, 0, sizeof(edge));
    memset(vis, 0, sizeof(vis));
}

void addEdge(int u, int v, int w, int first[])
{
    nxt[ecnt] = first[u], vv[ecnt] = v, ww[ecnt] = w, first[u] = ecnt++;
}

void dfsSize(int u, int f)
{
    siz[u] = 1;
    mson[u] = 0;

    travel(e, edge[u]) if(!vis[v] && v != f)
    {
        dfsSize(v, u);
        siz[u] += siz[v];
        mson[u] = max(mson[u], siz[v]);
    }
}

void dfsGravity(int r, int u, int f)
{
    mson[u] = max(mson[u], siz[r] - siz[u]);
    if(mson[u] < mi) mi = mson[u], root = u;

    travel(e, edge[u]) if(!vis[v] && v != f)
        dfsGravity(r, v, u);
}

void dfsDis(int u, int f, int d)
{
    dis[tot++] = d;

    travel(e, edge[u]) if(!vis[v] && v != f)
        dfsDis(v, u, d + ww[e]);
}

int calc(int u, int d = 0)
{
    tot = 0;
    dfsDis(u, 0, d);

    sort(dis, dis + tot);

    int ret = 0;
    int l = 0, r = tot - 1;
    while(l < r)
    {
        while(dis[l] + dis[r] > k && l < r) r--;
        ret += r - l;
        l++;
    }

    return ret;
}

void dfs(int u)
{
    mi = n;

    dfsSize(u, 0);
    dfsGravity(u, u, 0);
    ans += calc(root);
    vis[root] = true;

    travel(e, edge[root]) if(!vis[v])
    {
        ans -= calc(v, ww[e]);
        dfs(v);
    }
}

void work()
{
    while(scanf("%d%d", &n, &k) == 2 && (n || k))
    {
        init();

        ff(i, n-1)
        {
            int u, v, w;
            scanf("%d%d%d", &u, &v, &w);

            addEdge(u, v, w, edge);
            addEdge(v, u, w, edge);
        }

        dfs(1);
        cout << ans << endl;
    }
}


發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章