動態 DP 總結

注:部分參考 https://www.luogu.org/blog/gkxx-is-here/what-the-hell-is-ddp

動態DP,就是一個十分簡單的DP加了一個修改操作。
先看些例題:

例題1:模擬賽題

【問題描述】
某高校教學樓有 n 層,每一層有 2 個門,每層的兩個門和下一層之間的兩個門之間各有一條路(共 4 條),相同層的 2 個門之間沒有路。現給出如下兩個操作:
0 x y : 查詢第 x 層到第 y 層的路徑數量。
1 x y z : 改變第 x 層 的 y 門 到第 x+1 層的 z 門的通斷情況。
【輸入】
輸入文件名爲(road.in)。
第一行:兩個正整數 n m,表示共 n 層,m 個操作(2≤n≤50000,1≤m≤50000)接下來 m 行,
當第一個數爲 0 的時候 後面有兩個數 a,b (1≤a<b≤n)表示詢問第 a 層到第 b 層的路徑數量。
第一個數爲 1 的時候,後面有三個數 x, y, z (1≤x<n,1≤y,z≤2)表示改變第 x 層 的 y 門 到第 x+1 層的 z 門的通斷情況。
【輸出】
輸出文件名爲(road.out)。
輸出每一個詢問值。答案對 10^9+7 取模

這是最簡單的動態DP。
首先,發現有修改和詢問,而詢問又是區間查詢,自然想到線段樹維護。
直接的DP,肯定難以維護。考慮將\(dp_i到dp_{i+1}\)的變換轉化爲一個簡單的操作。
這是個計數問題,只有求和,顯然可以變爲矩陣乘法。就是\(dp_{i+1}\)等於\(dp_i\)乘一個矩陣。
這樣,通過矩乘優化,這個dp轉化爲了一段矩陣的乘積。
於是,問題變爲:有一些矩陣,支持修改一個矩陣,和查詢區間矩陣乘積。
線段樹很容易維護。

代碼:

#include <stdio.h>
#define ll long long
ll md=1000000007;
struct SJz
{
    ll jz[2][2];
    SJz operator*(SJz sz);
    void operator=(SJz sz)
    {
        jz[0][0]=sz.jz[0][0];
        jz[0][1]=sz.jz[0][1];
        jz[1][0]=sz.jz[1][0];
        jz[1][1]=sz.jz[1][1];
    }
};
SJz rtt,dw;
SJz SJz::operator*(SJz sz)
{
    for(int i=0;i<2;i++)
    {
        for(int j=0;j<2;j++)
        {
            rtt.jz[i][j]=0;
            for(int k=0;k<2;k++)
                rtt.jz[i][j]=(rtt.jz[i][j]+jz[i][k]*sz.jz[k][j])%md;
        }
    }
    return rtt;
}
SJz zh[200010];
void pushup(int i)
{
    zh[i]=zh[i<<1]*zh[(i<<1)|1];
}
void jianshu(int i,int l,int r)
{
    if(l+1==r)
    {
        zh[i].jz[0][0]=zh[i].jz[0][1]=zh[i].jz[1][0]=zh[i].jz[1][1]=1;
        return;
    }
    int m=(l+r)>>1;
    jianshu(i<<1,l,m);
    jianshu((i<<1)|1,m,r);
    pushup(i);
}
void xiugai(int i,int l,int r,int j,int x,int y)
{
    if(l+1==r)
    {
        zh[i].jz[x][y]^=1;
        return;
    }
    int m=(l+r)>>1;
    if(j<m)
        xiugai(i<<1,l,m,j,x,y);
    else
        xiugai((i<<1)|1,m,r,j,x,y);
    pushup(i);
}
SJz chaxun(int i,int l,int r,int L,int R)
{
    if(L<=l&&r<=R)
        return zh[i];
    if(r<=L||R<=l)
        return dw;
    SJz t1,t2;
    int m=(l+r)>>1;
    t1=chaxun(i<<1,l,m,L,R);
    t2=chaxun((i<<1)|1,m,r,L,R);
    return t1*t2;
}
int main()
{
    freopen("road.in","r",stdin);
    freopen("road.out","w",stdout);
    int n,m;
    scanf("%d%d",&n,&m);
    jianshu(1,1,n+1);
    dw.jz[0][0]=dw.jz[1][1]=1;
    dw.jz[0][1]=dw.jz[1][0]=0;
    for(int i=0;i<m;i++)
    {
        int lx;
        scanf("%d",&lx);
        if(lx==1)
        {
            int a,x,y;
            scanf("%d%d%d",&a,&x,&y);
            xiugai(1,1,n+1,a,x-1,y-1);
        }
        else
        {
            int x,y;
            scanf("%d%d",&x,&y);
            SJz jg=chaxun(1,1,n+1,x,y);
            printf("%I64d\n",(jg.jz[0][0]+jg.jz[0][1]+jg.jz[1][0]+jg.jz[1][1])%md);
        }
    }
    return 0;
}

非常好理解的。
然而,這是計數dp,只有加和乘,容易矩乘,但是通常的dp還是有\(min,max\)操作的。

例題2

1752467-20190824214239145-1182200638.png

1752467-20190824214247692-702619127.png

和上題一樣,考慮將轉移表示爲矩乘,然後線段樹維護。
但是,矩乘沒有\(min,max\)操作。
我們重新定義新的矩乘,** 使其滿足結合律,以方便線段樹維護 **。
1752467-20190824215302635-1354415165.png

這樣,用類似上一題的方法維護即可。
沒有代碼。

例題3:帶修改樹上最大獨立集。

這個樹形DP轉移很簡單:
1752467-20190824214919899-1805008103.png

但是,這題是樹,有多個兒子,不方便矩乘。

通常,若序列上用線段樹,那麼樹上就是樹剖套線段樹。
但是,轉移時針對所有兒子而言的,而樹剖只有一個重兒子。
所以,我們只能額外維護一個信息g,表示一個節點只算上它和它的的輕兒子的dp值,然後再用這個g和它的重兒子的dp值算出它的dp值。
將這個\(dp_v\)\(dp_u\)的過程寫成矩陣乘法,矩陣中包含g。
這樣,某個點的dp值就是重鏈上矩陣的乘積,用線段樹維護。
考慮修改:dp是通過g計算的,所以維護g即可。而修改一個點後,只有輕邊上的g值會被修改,沿着重鏈跳到根即可。
時間複雜度\(O(2^3*qlog^2n)\),能過\(10^5\)

代碼:

(常數巨大)

#include <stdio.h>
#define max(a,b) ((a)>(b)?(a):(b))
int inf=2100000000;
struct SJz
{
    int jz[2][2];
    SJz(){}
    SJz(int a,int b,int c,int d)
    {
        jz[0][0]=a;jz[0][1]=b;
        jz[1][0]=c;jz[1][1]=d;
    }
    void operator=(SJz x)
    {
        jz[0][0]=x.jz[0][0];
        jz[0][1]=x.jz[0][1];
        jz[1][0]=x.jz[1][0];
        jz[1][1]=x.jz[1][1];
    }
};
SJz operator*(SJz x,SJz y)
{
    SJz rt;
    for(int i=0;i<2;i++)
    {
        for(int j=0;j<2;j++)
            rt.jz[i][j]=-inf;
    }
    for(int i=0;i<2;i++)
    {
        for(int j=0;j<2;j++)
        {
            for(int k=0;k<2;k++)
            {
                if(x.jz[i][j]!=-inf&&y.jz[j][k]!=-inf)
                    rt.jz[i][k]=max(x.jz[i][j]+y.jz[j][k],rt.jz[i][k]);
            }
        }
    }
    return rt;
}
int fr[100010],ne[200010],v[200010],bs=0;
void addb(int a,int b)
{
    v[bs]=b;
    ne[bs]=fr[a];
    fr[a]=bs++;
}
int fa[100010],son[100010],top[100010],dn[100010];
int xl[100010],sz[100010],wz[100010],tm=0;
int g0[100010],g1[100010],f0[100010],f1[100010];
int dfs1(int u,int f)
{
    fa[u]=f;son[u]=-1;
    int ma=0,he=1;
    for(int i=fr[u];i!=-1;i=ne[i])
    {
        if(v[i]==f)
            continue;
        int rt=dfs1(v[i],u);
        he+=rt;
        if(rt>ma)
        {
            ma=rt;
            son[u]=v[i];
        }
    }
    return he;
}
void dfs2(int u,int f,int tp)
{
    top[u]=tp;
    wz[u]=++tm;xl[wz[u]]=u;
    if(son[u]==-1)
    {
        dn[u]=u;
        return;
    }
    dfs2(son[u],u,tp);
    for(int i=fr[u];i!=-1;i=ne[i])
    {
        if(v[i]!=f&&v[i]!=son[u])
            dfs2(v[i],u,v[i]);
    }
    dn[u]=dn[son[u]];
}
void dfs3(int u,int f)
{
    f0[u]=0;f1[u]=sz[u];
    g0[u]=0;g1[u]=sz[u];
    for(int i=fr[u];i!=-1;i=ne[i])
    {
        if(v[i]==f)
            continue;
        dfs3(v[i],u);
        int r0=f0[v[i]],r1=f1[v[i]];
        if(v[i]!=son[u])
        {
            g0[u]+=max(r0,r1);
            g1[u]+=r0;
        }
        f0[u]+=max(r0,r1);
        f1[u]+=r0;
    }
}
SJz ji[400010],I(0,-inf,-inf,0);
void ddz(int i,int l,int r)
{
    int u=xl[l];
    ji[i]=SJz(g0[u],g1[u],g0[u],-inf);
}
void jianshu(int i,int l,int r)
{
    if(l+1==r)
    {
        ddz(i,l,r);
        return;
    }
    int m=(l+r)>>1;
    jianshu(i<<1,l,m);
    jianshu((i<<1)|1,m,r);
    ji[i]=ji[(i<<1)|1]*ji[i<<1];
}
void xiugai(int i,int l,int r,int j)
{
    if(l+1==r)
    {
        ddz(i,l,r);
        return;
    }
    int m=(l+r)>>1;
    if(j<m)
        xiugai(i<<1,l,m,j);
    else
        xiugai((i<<1)|1,m,r,j);
    ji[i]=ji[(i<<1)|1]*ji[i<<1];
}
SJz getsum(int i,int l,int r,int L,int R)
{
    if(R<=l||r<=L)
        return I;
    if(L<=l&&r<=R)
        return ji[i];
    int m=(l+r)>>1;
    return getsum((i<<1)|1,m,r,L,R)*getsum(i<<1,l,m,L,R);
}
void getf(int u,int &f0,int &f1)
{
    SJz rt=getsum(1,1,tm+1,wz[u],wz[dn[u]]);
    int r0=0,r1=sz[dn[u]];
    f0=max(r0+rt.jz[0][0],r1+rt.jz[1][0]);
    f1=max(r0+rt.jz[0][1],r1+rt.jz[1][1]);
}
void update(int u,int n0,int n1)
{
    if(fa[u]!=0)
    {
        g0[fa[u]]=g0[fa[u]]-max(f0[u],f1[u])+max(n0,n1);
        g1[fa[u]]=g1[fa[u]]-f0[u]+n0;
        xiugai(1,1,tm+1,wz[fa[u]]);
    }
    f0[u]=n0;f1[u]=n1;
}
void xiugai(int x,int y)
{
    g1[x]=g1[x]-sz[x]+y;
    sz[x]=y;
    xiugai(1,1,tm+1,wz[x]);
    while(x!=0)
    {
        x=top[x];
        int n0,n1;
        getf(x,n0,n1);
        update(x,n0,n1);
        x=fa[x];
    }
}
void build()
{
    dfs1(1,0);
    dfs2(1,0,1);
    dfs3(1,0);
    jianshu(1,1,tm+1);
}
int main()
{
    int n,m;
    scanf("%d%d",&n,&m);
    for(int i=1;i<=n;i++)
    {
        scanf("%d",&sz[i]);
        fr[i]=-1;
    }
    for(int i=0;i<n-1;i++)
    {
        int a,b;
        scanf("%d%d",&a,&b);
        addb(a,b);addb(b,a);
    }
    build();
    for(int i=0;i<m;i++)
    {
        int x,y;
        scanf("%d%d",&x,&y);
        xiugai(x,y);
        printf("%d\n",max(f0[1],f1[1]));
    }
    return 0;
}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章