題面
題意
給你一個長度爲n的字符串,現要你加上m個字符,使其變爲一個迴文串,問有幾種加法。
做法
首先題目可以轉化爲,求有幾個長度爲n+m的字符串,使給出的字符串爲該字符串的子序列。
這樣可以考慮從兩邊開始確定字符,並與給出的字符串進行匹配,然後我們就可以根據此時的字符串已經匹配的位置建立自動機,這個自動機由多個節點構成,每個點都有一個權值爲24,25或26(僅終點是26)的自環,然後非自環會形成一個DAG,並且每條從起點出發的路徑,都是由幾個自環權值爲25或24的節點連接而成,最終通向自環權值爲26的目標點,對於這個性質可以將自動機的點壓縮至O(n)級別,新建n-1個自環權值爲24的點,標號並對相鄰的點連權值爲1的有向邊,再建(n+1)/2個自環權值爲25的點,標號並對相鄰的節點連權值爲1的有向邊,第(n+1)/2個點再向匯點(自環權值爲26)連一條權值爲1的有向邊,然後再在權值爲24,25的節點之間連邊,其權值就爲從起點開始經過a個24的點,再經過b個25的點的方案數(注意a可以爲0,可以由源點直接連向25的點),這個方案數可以用的dp求出。
然後發現若m+n爲奇數,則若最後一步在原串中還有兩個字符未匹配(且這兩個字符相同),不能直接轉移到最終狀態,可以用dp求出這樣的方案數,然後用上述矩陣,以這種不合法的狀態爲終點,減去不合法的狀態數量。
這題有點卡常,因爲這裏的矩陣都是上三角矩陣,所以矩陣乘法可以寫成這樣來減小常數
Jz operator * (const Jz &u) const
{
int i,j,k;
Jz res;
for(i=0; i<=L; i++)
{
for(j=i; j<=L; j++)
{
for(k=i; k<=L; k++)
{
Add(res.num[i][j],num[i][k]*u.num[k][j]%M);
}
}
}
return res;
}
代碼
#include<bits/stdc++.h>
#define N 310
#define M 10007
using namespace std;
int n,m,L,ans,dp[N][N][N],sum[N];
char str[N];
inline void Add(int &u,int v){u+=v,u%=M;}
struct Jz
{
int num[N][N];
Jz(){memset(num,0,sizeof(num));}
void clear(){memset(num,0,sizeof(num));}
Jz operator * (const Jz &u) const
{
int i,j,k;
Jz res;
for(i=0; i<=L; i++)
{
for(j=i; j<=L; j++)
{
for(k=i; k<=L; k++)
{
Add(res.num[i][j],num[i][k]*u.num[k][j]%M);
}
}
}
return res;
}
} dw,st,an;
inline Jz po(Jz u,int v)
{
int i;
Jz res;
for(i=0;i<=L;i++) res.num[i][i]=1;
for(;v;)
{
if(v&1) res=res*u;
u=u*u;
v>>=1;
}
return res;
}
int main()
{
int i,j,k,t;
scanf("%s%d",str+1,&m);
n=strlen(str+1);
dp[1][n][0]=1;
for(i=1; i<=n; i++)
{
for(j=n; j>=i; j--)
{
if(str[i]==str[j])
{
for(k=0; k<=n; k++)
{
if(!dp[i][j][k]) continue;
if(j-i>1) Add(dp[i+1][j-1][k],dp[i][j][k]);
else Add(sum[k],dp[i][j][k]);
}
}
else
{
for(k=0; k<=n; k++)
{
if(!dp[i][j][k]) continue;
Add(dp[i+1][j][k+1],dp[i][j][k]);
Add(dp[i][j-1][k+1],dp[i][j][k]);
}
}
}
}
st.num[0][1]=1;
st.num[0][n]=sum[0];
t=(n+1)/2;
L=n+t;
for(i=1;i<n;i++)
{
dw.num[i][L-(n-i+1)/2]=sum[i];
dw.num[i][i]=24;
if(i<n-1) dw.num[i][i+1]=1;
}
for(i=n;i<L;i++)
{
dw.num[i][i]=25;
dw.num[i][i+1]=1;
}
dw.num[L][L]=26;
an=st*po(dw,(n+m+1)>>1);
ans=an.num[0][L];
if((n+m)&1)
{
dw.clear();
memset(sum,0,sizeof(sum));
for(i=1;i<n;i++)
{
if(str[i]!=str[i+1]) continue;
for(j=0;j<=n;j++)
{
Add(sum[j],dp[i][i+1][j]);
}
}
st.num[0][n]=sum[0];
for(i=1;i<n;i++)
{
dw.num[i][L-(n-i+1)/2]=sum[i];
dw.num[i][i]=24;
if(i<n-1) dw.num[i][i+1]=1;
}
for(i=n;i<L;i++)
{
dw.num[i][i]=25;
dw.num[i][i+1]=1;
}
an=st*po(dw,(n+m+1)>>1);
Add(ans,M-an.num[0][L]);
}
cout<<ans;
}