BZOJ 3197 assassin(樹形DP+費用流)

題目鏈接:BZOJ 3197

題目大意:給出兩棵節點被染成黑白兩色的無根樹,問第一棵樹經過重標號後至少要反轉多少個節點的顏色使之與第二棵樹完全相同。

題解:類似BZOJ3162獨釣寒江雪 的解法,可以將樹的重心作爲根DP,設f[i][j]表示若使第一棵樹中以i爲根的子樹和第二棵樹中以j爲根的子樹完全相同需要反轉至少多少個節點的顏色。轉移的時候對於同構的子樹用費用流轉移(還是比較好理解的,詳見代碼)。

code(第一次寫費用流轉移的DP,有參考大牛們的BLOG)

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<queue>
#define N 705
#define seed 9875321
#define inf 1000000000
using namespace std;
inline int read()
{
    char c=getchar(); int num=0,f=1;
    while (c<'0'||c>'9') { if (c=='-') f=-1; c=getchar(); }
    while (c<='9'&&c>='0') { num=num*10+c-'0'; c=getchar(); }
    return num*f;
}
struct edge{
    int to,ne;
    bool del;
}e[N<<1];
int n,tot=1,head[N],root,fa[N],siz[N],f[N][N],tmp[N],a[N],b[N],deep[N],mn=inf,h[2];
inline void push(int x,int y) { e[++tot].to=y; e[tot].ne=head[x]; head[x]=tot; }
unsigned long long H[N];
pair<int,pair<unsigned long long,int> > w[N];
pair<unsigned long long,int> p1[N],p2[N];
struct Network{
    int S,T,tot,head[25],dis[25],pre[25]; bool vis[25];
    struct edge{
        int fr,to,ne,c,v;
    }e[450];
    void clear(int n) { S=0,T=n+1; tot=1; for (int i=S;i<=T;i++) head[i]=0; }
    void push(int x,int y,int flow,int cost)
    {
        e[++tot].fr=x; e[tot].to=y; e[tot].v=flow; e[tot].c=cost; e[tot].ne=head[x]; head[x]=tot;
        e[++tot].fr=y; e[tot].to=x; e[tot].v=0; e[tot].c=-cost; e[tot].ne=head[y]; head[y]=tot;
    }
    bool spfa()
    {
        for (int i=S;i<=T;i++) dis[i]=inf;
        queue<int> q; dis[S]=0; vis[S]=true; q.push(S);
        while (!q.empty())
        {
            int now=q.front(); q.pop();
            for (int i=head[now];i;i=e[i].ne)
            {
                int v=e[i].to;
                if (e[i].v&&dis[now]+e[i].c<dis[v])
                {
                    dis[v]=dis[now]+e[i].c;
                    pre[v]=i;
                    if (!vis[v]) vis[v]=true,q.push(v);
                }
            }
            vis[now]=false;
        }
        return dis[T]!=inf;
    }
    int mcf()
    {
        for (int i=T;i!=S;i=e[pre[i]].fr) e[pre[i]].v--,e[pre[i]^1].v++;
        return dis[T];
    }
}flow;
void getrt(int now,int pre)
{
    siz[now]=1; int tmp=0;
    for (int i=head[now];i;i=e[i].ne)
    {
        int v=e[i].to; if (v==pre) continue;
        getrt(v,now); siz[now]+=siz[v];
        tmp=max(tmp,siz[v]);
    }
    tmp=max(tmp,n-siz[now]);
    if (tmp<mn) mn=tmp,h[0]=now,h[1]=0;
     else if (tmp==mn) h[1]=now;
}
void geth(int now,int pre,int dep)
{
    fa[now]=pre; deep[now]=dep;
    for (int i=head[now];i;i=e[i].ne)
    {
        int v=e[i].to; if (v==pre||e[i].del) continue;
        geth(v,now,dep+1);
    }
    int top=0;
    for (int i=head[now];i;i=e[i].ne)
    {
        int v=e[i].to; if (v==pre||e[i].del) continue;
        tmp[++top]=H[v];
    }
    sort(tmp+1,tmp+top+1);
    H[now]=233;
    for (int i=1;i<=top;i++) (((H[now]*=seed)^=tmp[i])+=tmp[i])^=tmp[i];
}
void solve(int x,int y)
{
    int s1=0,s2=0;
    for (int i=head[x];i;i=e[i].ne)
    {
        int v=e[i].to; if (v==fa[x]||e[i].del) continue;
        p1[++s1]=make_pair(H[v],v);
    }
    for (int i=head[y];i;i=e[i].ne)
    {
        int v=e[i].to; if (v==fa[y]||e[i].del) continue;
        p2[++s2]=make_pair(H[v],v);
    }
    sort(p1+1,p1+s1+1); sort(p2+1,p2+s2+1);
    for (int i=1;i<=s1;i++)
    {
        int j=i;
        while (j<s1&&p1[j+1].first==p1[j].first) j++;
        int len=j-i+1;
        flow.clear(len*2);
        for (int k=i;k<=j;k++)
         for (int l=i;l<=j;l++)
          flow.push(k-i+1,l-i+1+len,1,f[p1[k].second][p2[l].second]);
        for (int k=1;k<=len;k++)
         flow.push(flow.S,k,1,0),flow.push(k+len,flow.T,1,0);
        while (flow.spfa()) f[x][y]+=flow.mcf();
        i=j;
    }
    if (a[x]!=b[y]) f[x][y]++;
}
int main()
{
    n=read();
    for (int i=1;i<n;i++)
    {
        int x=read(),y=read();
        push(x,y); push(y,x);
    }
    for (int i=1;i<=n;i++) a[i]=read();
    for (int i=1;i<=n;i++) b[i]=read();
    getrt(1,0);
    if (h[1])
    {
        for (int i=head[h[0]];i;i=e[i].ne)
         if (e[i].to==h[1]) e[i].to=e[i^1].to=root=n+1;
        push(n+1,h[0]); push(n+1,h[1]); n++;
    }
    else root=h[0];
    geth(root,0,1);
    for (int i=1;i<=n;i++) w[i]=make_pair(-deep[i],make_pair(H[i],i));
    sort(w+1,w+n+1);
    for (int i=1;i<=n;i++)
    {
        int j=i;
        while (j<n&&w[j+1].first==w[j].first&&w[j+1].second.first==w[j].second.first) j++;
        for (int k=i;k<=j;k++)
         for (int l=i;l<=j;l++)
          solve(w[k].second.second,w[l].second.second);
        i=j;
    }
    printf("%d",f[root][root]);
    return 0;
}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章