Rainbow Roads(dfs序 + 差分)

題目鏈接:Rainbow Roads

題意:給你一個 n 個節點的樹,給你 n - 1 條邊,以及每條邊的顏色。
定義彩虹路:路上的相鄰的兩條邊顏色不同。
定義彩虹點:以這個點爲起點,其他所有點爲終點的所有路都是彩虹路。
問你有多少個彩虹點並從小到大輸出編號。

#include<bits/stdc++.h>
#define inf 0x3f3f3f3f
#define lowbit(i) i & (-i)
using namespace std;

typedef long long ll;

const int N = 50010;

struct node
{
    int to, next, w;
} edge[N * 2];

int n;
int a[N], top1;
int b[N];
int in[N];
int vis[N];
int out[N], top2;
int head[N], tot;
int ans[N], top;

void add(int u, int v, int w)
{
    edge[tot].to = v;
    edge[tot].w = w;
    edge[tot].next = head[u];
    head[u] = tot++;
}

void get_dfs(int u, int pre)
{
    a[++top1] = u;
    in[u] = ++top2;
    for(int i = head[u]; i != -1; i = edge[i].next)
    {
        int v = edge[i].to;
        if(v != pre) get_dfs(v, u);
    }
    out[u] = top2;
}

void dfs(int u, int pre, int col)
{
    for(int i = head[u]; i != -1; i = edge[i].next)
    {
        int v = edge[i].to;
        int w = edge[i].w;
        if(v != pre) vis[w]++;
    }
    for(int i = head[u]; i != -1; i = edge[i].next)
    {
        int v = edge[i].to;
        int w = edge[i].w;
        if(v != pre)
        {
            if(w == col)
            {
                int now1 = in[v];
                int now2 = out[v];
                b[now1]--;
                b[now2 + 1]++;
                now1 = in[1];
                now2 = in[u] - 1;
                b[now1]--;
                b[now2 + 1]++;
                now1 = out[u] + 1;
                now2 = top1;
                b[now1]--;
                b[now2 + 1]++;
            }
            if(vis[w] > 1)
            {
                int now1 = in[v];
                int now2 = out[v];
                b[now1]--;
                b[now2 + 1]++;
            }
        }
    }
    for(int i = head[u]; i != -1; i = edge[i].next) vis[edge[i].w] = 0;
    for(int i = head[u]; i != -1; i = edge[i].next)
    {
        int v = edge[i].to;
        int w = edge[i].w;
        if(v != pre)
        {
            dfs(v, u, w);
        }
    }
}

int main()
{
    tot = 0;
    memset(head, -1, sizeof(head));
    int u, v, w;
    scanf("%d", &n);
    for(int i = 1; i < n; i++)
    {
        scanf("%d %d %d", &u, &v, &w);
        add(u, v, w);
        add(v, u, w);
    }
    get_dfs(1, 0);
    dfs(1, 0, 0);
    ll sum = 0;
    for(int i = 1; i <= top1; i++)
    {
        sum += b[i];
        if(sum >= 0)
        {
            ans[top++] = a[i];
        }
    }
    sort(ans, ans + top);
    printf("%d\n", top);
    for(int i = 0; i < top; i++) printf("%d\n", ans[i]);
    return 0;
}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章