首先我們先求出每個點到葉節點的最大距離。從編號爲1的節點開始,存在一個數組中。
然後在用rmq預處理這個數組的最大值和最小值
最後二分查找答案。
現在我們討論如何解決這3步
第一步:樹形dp,需要兩次遍歷。第一次遍歷,任取一個節點爲根節點u,樹的層次就確定了。可以遞歸求出每個點到其子孫的最遠距離。每個節點記錄兩個最值,最大值和次大值,並記錄這兩個值來自哪個節點。注意:這兩個值來自不同的子樹中。然後進行第二次遍歷,從u開始遍歷子節點。子節點的最值要麼來自其子孫中,要麼來自其父節點的最值中。這裏我們可以用父節點的最值來更新子節點的最值。若父節點最值來自這個子節點,則用次大值更新,否則用最大值更新。最後我們把每個節點的最大值按編號從1-n存在一個數組中。
第二步:rmq預處理出數組的最大值和最小值,這一步不難。
第三部:題目要求差值不大於Q,所以我們枚舉人數,看是否符合,取最大值。
複雜度分析:
第一步求最大距離,2次遍歷都是O(n),第二部rmq預處理爲O(nlgn),第三部查詢最多有500,二分答案lgn,然後枚舉起點需要O(n-m),m爲長度,最壞情況爲:500*lgn*(n-m)。
優化:
二分查找時,每次每個長度求出後即記錄下來,下一次再訪問時可以避免枚舉,直接得到答案。不加會TLE。
Rmq查詢時,計算區間長度的log值,以前都調用函數庫的log函數,數學函數處理double值,速度一般比較慢,在多查詢中體現的明顯,後來預處理出來,然後就過了,不然會TLE。
程序最後優化的比較快,265MS,在杭電排Rank1
#include <iostream>
#include <cstdio>
#include <cstring>
using namespace std;
const int N = 51000;
const int M = 17;
const int eps = 1e-9;
int d[N][2], Fm[N][2], head[N];
int mi[M][N], mx[M][N];
int Log[N];
int ans[N];
struct EDG{
int u, v, c, next;
}g[N * 2];
int cnt, n;
void add(int u, int v, int c)
{
g[cnt].u = u; g[cnt].v = v; g[cnt].c = c; g[cnt].next = head[u]; head[u] = cnt++;
g[cnt].u = v; g[cnt].v = u; g[cnt].c = c; g[cnt].next = head[v]; head[v] = cnt++;
}
int dfs(int cur, int p)
{
for(int e = head[cur]; e != -1; e = g[e].next){
int v = g[e].v;
if(v != p){
int t = dfs(v, cur) + g[e].c;
if(d[cur][1] < t){
d[cur][0] = d[cur][1]; Fm[cur][0] = Fm[cur][1];
d[cur][1] = t; Fm[cur][1] = v;
}else if(d[cur][0] < t){
d[cur][0] = t; Fm[cur][0] = v;
}
}
}
return d[cur][1];
}
void Dp(int cur, int p, int c)
{
if(cur != Fm[p][1]){
if(d[cur][1] < d[p][1] + c){
d[cur][0] = d[cur][1]; Fm[cur][0] = Fm[cur][1];
d[cur][1] = d[p][1] + c; Fm[cur][1] = p;
}else if(d[cur][0] < d[p][1] + c){
d[cur][0] = d[p][1] + c; Fm[cur][0] = p;
}
}else {
if(d[cur][1] < d[p][0] + c){
d[cur][0] = d[cur][1]; Fm[cur][0] = Fm[cur][1];
d[cur][1] = d[p][0] + c; Fm[cur][1] = p;
}else if(d[cur][0] < d[p][0] + c){
d[cur][0] = d[p][0] + c; Fm[cur][0] = p;
}
}
for(int e = head[cur]; e != -1; e = g[e].next){
int v = g[e].v;
if(v != p){
Dp(v, cur, g[e].c);
}
}
}
void ST()
{
for(int i = 1; i <= n; i++)
mi[0][i] = mx[0][i] = d[i][1];
for(int i = 1; (1 << i) < n; i++){
for(int j = n; j >= 1; j--){
mx[i][j] = mx[i - 1][j];
if(j + (1 << (i - 1)) <= n)
mx[i][j] = max(mx[i][j], mx[i - 1][j + (1 << (i - 1))]);
mi[i][j] = mi[i - 1][j];
if(j + (1 << (i - 1)) <= n)
mi[i][j] = min(mi[i][j], mi[i - 1][j + (1 << (i - 1))]);
}
}
}
inline int rmq(int l, int r)
{
int m = Log[r - l + 1];
int a = max(mx[m][l], mx[m][r - (1 << m) + 1]);
int b = min(mi[m][l], mi[m][r - (1 << m) + 1]);
return a - b;
}
inline int check(int m)
{
int &res = ans[m];
if(res != -1) return res;
for(int i = 1; i + m - 1 <= n; i++){
if(res == -1) res = rmq(i, i + m - 1);
else res = min(res, rmq(i, i + m - 1));
}
return res;
}
int main()
{
//freopen("input.txt", "r", stdin);
int m;
int u, v, c;
for(int i = 0, j = 0; i <= 50000; i++){
if((1 << j) >= ((i + 1) >> 1)){
Log[i] = j;
}else{
Log[i] = ++j;
}
}
while(scanf("%d %d", &n, &m) == 2 && m + n){
cnt = 0;
memset(head, -1, sizeof(head));
memset(d, 0, sizeof(d));
memset(Fm, 0, sizeof(Fm));
for(int i = 1; i < n; i++){
scanf("%d %d %d", &u, &v, &c);
add(u, v, c);
}
dfs(1, 0);
Dp(1, 0, 0);
ST();
memset(ans, -1, sizeof(ans));
ans[1] = 0;
for(int i = 0; i < m; i++){
scanf("%d", &u);
int l = 1, r = n + 1;
while(l < r){
int mid = (l + r) >> 1;
if(check(mid) <= u){
l = mid + 1;
}else r = mid;
}
printf("%d\n", l - 1);
}
}
return 0;
}