這道題思路並不難想(假的!):先用manacher算法求出以s[i]爲中心的最長迴文子串左右擴展的長度,再分別推出以s[i]結尾和開頭的迴文子串(注:不一定是最長的。如原串爲aacaa,則i=3時,en[i]=2而不是1)數量,然後其中一組乘上另外一組的後綴和(前綴和)相加即可。簡單分析一下第一個樣例:
原數組的下標 i: 0 1 2
a c a
以s[i]結尾的迴文串數量: 1 1 2 -> en[]數組
以s[i]開頭的迴文串數量: 2 1 1 -> st[]數組
則答案爲1×(1+1) + 1×1 = 3
問題主要是如何推出st,en數組。我一開始是這麼寫的:
for(int i=2;i<l;i++)
{
int k=i+p[i]-1;//最長迴文串擴展的右邊界(a數組中)
int j=ceil((i*1.0-2)/2);//迴文串中心(映射到原s數組下標)
while(j<=(k-2)/2)//同樣要映射到s數組的下標
en[j]++,j++;
}
for(int i=0,j=len-1;i<len&&j>=0;i++,j--)
st[i]=en[j];
兩層循環果斷TLE。。只能換思路。。後來網上一搜發現可以用樹狀數組來維護。
先附上大佬博客Orz:https://www.cnblogs.com/liyinggang/p/5675916.html
https://blog.csdn.net/hexianhao/article/details/51823113
以en數組爲例,假設以i爲中心(a[i]不一定是字母,要“映射”到新下標(從1開始)),則從i到i+p[i]-1這些點的en值都要加1。我不太明白爲什麼把到右端點的數-1,再把到中心的數+1(我感覺剛好反了)。後來我寫了一發但還是WA了QAQ...不明白爲什麼會醬紫55555...可能還是對樹狀數組的理解不夠?把solve函數放在這裏,如果有大神知道爲什麼錯了還請不吝賜教^_^
ll st[MAX];//以s[i]開頭的迴文子串數
ll en[MAX];//以s[i]結尾的迴文子串數
ll c[MAX];
ll sum[MAX];
int lowbit(int x)
{
return x&(-x);
}
void update(int i,int val)
{
while(i<=n)
{
c[i]+=val;
i+=lowbit(i);
}
}
int get_sum(int i)
{
int ret=0;
while(i>0)
{
ret+=c[i];
i-=lowbit(i);
}
return ret;
}
void fun()
{
memset(c,0,sizeof(c));
for(int i=2;i<l;i++)
{
int mid;//迴文串中心(映射到原s數組下標,但注意下標從1開始)
if(i%2==0)//字母
mid=i/2;
else //'#'
mid=(i+1)/2;
int k=i+p[i]-1;//最長迴文串擴展的右邊界(a數組中)
int r=k/2+1;//映射到原s數組
if(r<=mid)
continue;
//cout<<"i="<<i<<" l="<<l<<" r="<<r<<endl;
update(mid,1);
update(r,-1);
//c[mid,r]+1
}
}
void solve() //得到st,en數組
{
memset(st,0,sizeof(st));
memset(en,0,sizeof(en));
fun();
for(int i=1;i<=n;i++)
en[i]=get_sum(i);
/*for(int i=1;i<=n;i++)
cout<<"i="<<i<<" en[i]="<<en[i]<<endl;*/
reverse(s,s+n);
manacher();
fun();
for(int i=1;i<=n;i++)
st[i]=get_sum(i);
}
我還是選擇繼續掙扎。。再一搜題解,發現可以用“差分前綴和”(第一次聽說QAQ...)
附上講解博客Orz:https://blog.csdn.net/hzk_cpp/article/details/80407014
https://www.cnblogs.com/lulizhiTopCoder/p/8384784.html
感覺有點像樹狀數組,處理思路也差不多。再附上本題的參考博客Orz:
https://blog.csdn.net/gatevin/article/details/44775533
經過長期掙扎。。終於過了。。注意WA點開long long啊!!!!!附上AC代碼:
#include<cstdio>
#include<cstring>
#include<cmath>
#include<iostream>
#include<algorithm>
#include<vector>
#include<map>
#include<set>
#include<stack>
#include<queue>
using namespace std;
#define ll long long
typedef pair<int,int>pp;
#define mkp make_pair
#define pb push_back
const int INF=0x3f3f3f3f;
const ll MOD=1e9+(ll)7;
const int MAX=100010;
char s[MAX];
int n;
char a[MAX*2];
int len,p[MAX*2];//以s[i]爲中心的最長迴文子串右(左)擴展的長度
void manacher()
{
memset(p,0,sizeof(p));
len=0;
a[len++]='$';
a[len++]='#';
for(int i=0;i<n;i++)
{
a[len++]=s[i];
a[len++]='#';
}
a[len]='\0';
int mx=0,id=0;
for(int i=0;i<len;i++)
{
p[i]=(mx-i)?min(p[2*id-i],mx-i):1;
while(a[i+p[i]]==a[i-p[i]])
p[i]++;
if(i+p[i]>mx)
{
mx=i+p[i];
id=i;
}
}
/*for(int i=0;i<len;i++)
cout<<a[i]<<" ";
cout<<endl;
for(int i=0;i<len;i++)
cout<<p[i]<<" ";
cout<<endl;
cout<<"len="<<len<<endl;*/
}
ll st[MAX];//以s[i]開頭的迴文子串數
ll en[MAX];//以s[i]結尾的迴文子串數
ll dp1[MAX],dp2[MAX];//差分
ll sum[MAX];//st的後綴和
void solve() //得到st,en數組
{
memset(dp1,0,sizeof(dp1));
memset(dp2,0,sizeof(dp2));
for(int i=2;i<len;i++)
{
int l,r;
l=i-(p[i]-1);r=i;//a數組下標,要映射到s數組
if(l%2) l/=2;
else l=l/2-1;
r=r/2-1;
if(l<=r)
dp1[l]++,dp1[r+1]--;//對應st數組
l=i;r=i+(p[i]-1);
if(l%2) l/=2;
else l=l/2-1;
r=r/2-1;
if(l<=r)
dp2[l]++,dp2[r+1]--;//對應en數組
}
st[0]=dp1[0];en[0]=dp2[0];
for(int i=1;i<n;i++)
{
st[i]=st[i-1]+dp1[i];
en[i]=en[i-1]+dp2[i];
}
/*for(int i=0;i<n;i++)
cout<<en[i]<<" ";
cout<<endl;
for(int i=0;i<n;i++)
cout<<st[i]<<" ";
cout<<endl;*/
}
int main()
{
while(scanf("%s",s)==1)
{
n=strlen(s);
manacher();
solve();
sum[n-1]=st[n-1];//st的後綴和
for(int i=n-2;i>=0;i--)
sum[i]=sum[i+1]+st[i];
ll ans=0;
for(int i=0;i<n-1;i++)
{
ans+=en[i]*sum[i+1];
}
printf("%lld\n",ans);
}
return 0;
}
另外這道題更普遍的解法是迴文樹,但是我不會啊55555...回頭有時間好好學習一下,再做一下這道題吧。