題目傳送門
題意:
給你兩個字符串 ,長度分別是 。
讓你從兩個字符串中各自挑出一個子串,問這兩個子串相同的方案數。
數據範圍: 。
題解:
把兩個字符串拼接起來,中間需要連接一個沒有用過的字符,記作#。
連接後的字符串記作 。
設 表示 字符串 中挑出兩個子串相同的方案數。
這道題答案是 。容斥想一想就好了。
考慮怎麼求 。
計數題就要考慮倒序height + 並查集。這算是套路吧。
我們倒序排列 後。
連接 。
在合併時,我們累加 就好了。
感受:
遇到計數題要考慮一下倒序height + 並查集。
這個思路真好用。
代碼:
#include<bits/stdc++.h>
using namespace std ;
typedef long long ll ;
const int maxn = 4e5 + 5 ;
int rk[maxn << 1] , sa[maxn << 1] , height[maxn << 1] ;
int tmp[maxn << 1] , cnt[maxn] ;
char s[maxn] , t[maxn] ;
int pre[maxn] , siz[maxn] ;
struct node
{
int x , id ;
bool operator < (const node &s) const
{
if(x != s.x) return x > s.x ;
else return id < s.id ;
}
} h[maxn] ;
void suffixarray(int n , int m)
{
n ++ ;
for(int i = 0 ; i < n * 2 + 5 ; i ++)
rk[i] = sa[i] = height[i] = tmp[i] = 0 ;//開2 倍空間
for(int i = 0 ; i < m ; i ++) cnt[i] = 0 ;
for(int i = 0 ; i < n ; i ++) cnt[rk[i] = s[i]] ++ ;
for(int i = 1 ; i < m ; i ++) cnt[i] += cnt[i - 1] ;
for(int i = 0 ; i < n ; i ++) sa[-- cnt[rk[i]]] = i ;
for(int k = 1 ; k <= n ; k <<= 1)
{
int j = 0 ;
for(int i = 0 ; i < n ; i ++)
{
j = sa[i] - k ;
if(j < 0) j += n ;
tmp[cnt[rk[j]] ++] = j ;
}
sa[tmp[cnt[0] = 0]] = j = 0 ;
for(int i = 1 ; i < n ; i ++)
{
if(rk[tmp[i]] != rk[tmp[i - 1]]
|| rk[tmp[i] + k] != rk[tmp[i - 1] + k])
cnt[++ j] = i ;
sa[tmp[i]] = j ;
}
memcpy(rk , sa , n * sizeof(int)) ;
memcpy(sa , tmp , n * sizeof(int)) ;
if(j >= n - 1) break ;
}
height[0] = 0 ;
for(int i = 0 , k = 0 , j = rk[0] ; i < n - 1 ; i ++ , k ++)
while(~k && s[i] != s[sa[j - 1] + k])
height[j] = k -- , j = rk[sa[j] + 1] ;
}
int find(int u)
{
if(pre[u] == u) return u ;
return pre[u] = find(pre[u]) ;
}
ll join(int k , int x , int y)
{
ll ans = 0 ;
int fx = find(x) ;
int fy = find(y) ;
if(fx != fy)
{
ans += ll(k) * siz[fx] * siz[fy] ;
pre[fx] = fy ;
siz[fy] += siz[fx] ;
}
//cout << k << ' ' << x << ' ' << y << ' ' << ans << '\n' ;
return ans ;
}
ll solve1()
{
int len ;
ll ans = 0 ;
scanf("%s" , s) ;
len = strlen(s) ;
suffixarray(len , 200) ;
for(int i = 0 ; i <= len ; i ++) pre[i] = i , siz[i] = 1 ;
for(int i = 2 ; i <= len ; i ++)
h[i].x = height[i] , h[i].id = i ;
sort(h + 2 , h + len + 1) ;
for(int i = 2 ; i <= len ; i ++)
{
if(h[i].x == 0) continue ;
ans += join(h[i].x , sa[h[i].id - 1] , sa[h[i].id]) ;
}
return ans ;
}
ll solve2()
{
int len = strlen(s) ;
ll ans = 0 ;
suffixarray(len , 200) ;
for(int i = 0 ; i <= len ; i ++) pre[i] = i , siz[i] = 1 ;
for(int i = 2 ; i <= len ; i ++)
h[i].x = height[i] , h[i].id = i ;
sort(h + 2 , h + len + 1) ;
for(int i = 2 ; i <= len ; i ++)
{
if(h[i].x == 0) continue ;
ans += join(h[i].x , sa[h[i].id - 1] , sa[h[i].id]) ;
}
return ans ;
}
int main()
{
ll ans = 0 ;
ans -= solve1() ;
int len1 = strlen(s) ;
for(int i = 0 ; i < len1 ; i ++) t[i] = s[i] ;
ans -= solve1() ;
int len2 = strlen(s) ;
s[len2] = '#' ;
for(int i = len2 + 1 ; i <= len2 + len1 ; i ++)
s[i] = t[i - len2 - 1] ;
s[len1 + len2 + 1] = 0 ;
ans += solve2() ;
printf("%lld\n" , ans) ;
return 0 ;
}