奇怪的FFT+manacher以及容斥(?)。
題目要求求出不連續的迴文序列,首先要想到用所有迴文序列減掉連續的迴文序列,連續的顯然可以用manacher求出來,於是題目轉化爲求出所有的迴文序列。巨大的腦洞:由於題目所給的字符串只包含a和b兩種字符,所以我們將a看成1,b看成0做一遍FFT,然後將b看成1,a看成0再做一遍,就可以求出所有的對稱字符對的對稱中心下標乘2包含的對稱字符對(大概理解一下)。這中間有個問題:對稱中心有可能在字符上,也有可能在兩個字符之間。這在FFT的作用下導致了一個問題:對於兩個下標不同的字符對,它們會對其對稱中心下標乘2的位置帶來2的貢獻;而對於同一個字符,它會給自己帶來1的貢獻。對於這個問題的處理,網上大多數的題解都是將貢獻除以2然後上取整,於是導致我懵逼了一天,最後在zyz神犇的引導下解開了謎團;謎團的本質在之前已經說了,就不再詳細贅述做法了。其餘見代碼。
CODE:
#include<cmath>
#include<cstdio>
#include<cstring>
#include<complex>
#include<iostream>
using namespace std;
#define mod 1000000007ll
typedef long long ll;
typedef complex<double> cp;
const double pi=acos(-1);
const int N=3e5+10;
cp a[N],b[N];
char s[N];
int r[N],Ans[N],p[N],power[N];
int len,n,m,Len;
ll ans;
inline int getstring()
{
char c=getchar();int len=0;
while(c!='a'&&c!='b') c=getchar();
while(c=='a'||c=='b') s[++len]=c,c=getchar();
return len;
}
inline void fft(cp *a,int f)
{
for(int i=0;i<n;i++)
if(i<r[i]) swap(a[i],a[r[i]]);
for(int i=1;i<n;i<<=1)
{
cp wn(cos(pi/i),f*sin(pi/i));
for(int j=0,tmp=i<<1;j<n;j+=tmp)
{
cp w(1,0);
for(int k=0;k<i;k++,w*=wn)
{
cp x=a[j+k],y=w*a[j+k+i];
a[j+k]=x+y,a[j+k+i]=x-y;
}
}
}
}
inline ll manacher()
{
int mx=0,pos=0;
ll ans=0;
for(int i=len;i;i--)
s[i<<1]=s[i],s[(i<<1)-1]='#';
len<<=1;
s[len+1]=s[len+2]='#',s[0]='$';
for(int i=1;i<=len;i++)
{
if(mx>i) p[i]=min(p[pos*2-i],mx-i);
else p[i]=1;
while(s[i-p[i]]==s[i+p[i]]) p[i]++;
if(i+p[i]>mx) mx=i+p[i],pos=i;
ans+=1ll*p[i]>>1;
}
return ans;
}
int main()
{
len=getstring();m=len<<1;
for(n=1;n<m;n<<=1) Len++;
for(int i=0;i<n;i++)
r[i]=(r[i>>1]>>1)|((i&1)<<(Len-1));
for(int i=1;i<=len;i++)
if(s[i]=='a') a[i-1]=1;
else b[i-1]=1;
fft(a,1),fft(b,1);
for(int i=0;i<n;i++)
a[i]*=a[i],b[i]*=b[i];
fft(a,-1),fft(b,-1);
power[0]=1;
for(int i=1;i<=m;i++)
{
power[i]=power[i-1]<<1;
if(power[i]>=mod) power[i]-=mod;
}
for(int i=0;i<m;i++)
{
int tmp=(int)(a[i].real()/n+0.5)+(int)(b[i].real()/n+0.5)+1;
ans+=power[tmp>>1]-1;
if(ans>=mod) ans-=mod;
}
/*和上一個for的本質相同
for(int i=0;i<m;i++)
{
if(i&1)
{
int tmp=(int)(a[i].real()/n+0.5)+(int)(b[i].real()/n+0.5);
ans+=power[tmp>>1]-1;
}
else
{
int tmp=(int)(a[i].real()/n+0.5)+(int)(b[i].real()/n+0.5);
ans+=power[(tmp>>1)+1]-1;
}
if(ans>=mod) ans-=mod;
}
*/
printf("%lld",((ans-manacher())%mod+mod)%mod);
return 0;
}