P3181 後綴數組 + 容斥 + 並查集

題目傳送門

題意:

給你兩個字符串 \dpi{150}s,t ,長度分別是 n1,n2 。

讓你從兩個字符串中各自挑出一個子串,問這兩個子串相同的方案數。

數據範圍: 1 \leqslant n1 , n2 \leqslant 2 \cdot 10^5 。

題解:

把兩個字符串拼接起來,中間需要連接一個沒有用過的字符,記作#。

連接後的字符串記作 p 。

設 f(x) 表示 字符串 x 中挑出兩個子串相同的方案數。

這道題答案是 f(p) - f(s) - f(t) 。容斥想一想就好了。

考慮怎麼求 f(x) 。

計數題就要考慮倒序height + 並查集。這算是套路吧。

我們倒序排列 height 後。

height[i] 連接 sa[i-1],sa[i] 。

在合併時,我們累加 height[i] * siz[sa[i-1]] * siz[sa[i]] 就好了。

感受:

遇到計數題要考慮一下倒序height + 並查集。

這個思路真好用。

代碼:

#include<bits/stdc++.h>
using namespace std ;
typedef long long ll ;
const int maxn = 4e5 + 5 ;
int rk[maxn << 1] , sa[maxn << 1] , height[maxn << 1] ;
int tmp[maxn << 1] , cnt[maxn] ;
char s[maxn] , t[maxn] ;
int pre[maxn] , siz[maxn] ;
struct node
{
	int x , id ;
	bool operator < (const node &s) const
	{
		if(x != s.x)  return x > s.x ;
		else  return id < s.id ;
	}
} h[maxn] ;
void suffixarray(int n , int m)
{
   n ++ ;
   for(int i = 0 ; i < n * 2 + 5 ; i ++)
     rk[i] = sa[i] = height[i] = tmp[i] = 0 ;//開2 倍空間
   for(int i = 0 ; i < m ; i ++)  cnt[i] = 0 ;
   for(int i = 0 ; i < n ; i ++)  cnt[rk[i] = s[i]] ++ ;
   for(int i = 1 ; i < m ; i ++)  cnt[i] += cnt[i - 1] ;
   for(int i = 0 ; i < n ; i ++)  sa[-- cnt[rk[i]]] = i ;
   for(int k = 1 ; k <= n ; k <<= 1)
   {
     int j = 0 ;
     for(int i = 0 ; i < n ; i ++)
     {
       j = sa[i] - k ;
       if(j < 0)  j += n ;
       tmp[cnt[rk[j]] ++] = j ;
     }
     sa[tmp[cnt[0] = 0]] = j = 0 ;
     for(int i = 1 ; i < n ; i ++)
     {
       if(rk[tmp[i]] != rk[tmp[i - 1]]
       || rk[tmp[i] + k] != rk[tmp[i - 1] + k])
         cnt[++ j] = i ;
       sa[tmp[i]] = j ;
     }
     memcpy(rk , sa , n * sizeof(int)) ;
     memcpy(sa , tmp , n * sizeof(int)) ;
     if(j >= n - 1)  break ;
   }
   height[0] = 0 ;
   for(int i = 0 , k = 0 , j = rk[0] ; i < n - 1 ; i ++ , k ++)
     while(~k && s[i] != s[sa[j - 1] + k])
       height[j] = k -- , j = rk[sa[j] + 1] ;
}
int find(int u) 
{
   if(pre[u] == u)  return u ;
   return pre[u] = find(pre[u]) ;
}
ll join(int k , int x , int y)
{
   ll ans = 0 ;
   int fx = find(x) ;
   int fy = find(y) ;
   if(fx != fy)  
   {
   	  ans += ll(k) * siz[fx] * siz[fy] ;
   	  pre[fx] = fy ;
   	  siz[fy] += siz[fx] ;
   }  
   //cout << k << ' ' << x << ' ' << y << ' ' << ans << '\n' ; 
   return ans ;
}
ll solve1()
{
	int len ;
	ll ans = 0 ;
	scanf("%s" , s) ;
	len = strlen(s) ;
	suffixarray(len , 200) ;
	for(int i = 0 ; i <= len ; i ++)  pre[i] = i , siz[i] = 1 ;
	for(int i = 2 ; i <= len ; i ++)
	  h[i].x = height[i] , h[i].id = i ;
	sort(h + 2 , h + len + 1) ;
	for(int i = 2 ; i <= len ; i ++)
	{
		if(h[i].x == 0)  continue ;
		ans += join(h[i].x , sa[h[i].id - 1] , sa[h[i].id]) ;
	}
	return ans ;
}
ll solve2()
{
	int len = strlen(s) ;
	ll ans = 0 ;
	suffixarray(len , 200) ;
	for(int i = 0 ; i <= len ; i ++)  pre[i] = i , siz[i] = 1 ;
	for(int i = 2 ; i <= len ; i ++)
	  h[i].x = height[i] , h[i].id = i ;
	sort(h + 2 , h + len + 1) ;
	for(int i = 2 ; i <= len ; i ++)
	{
		if(h[i].x == 0)  continue ;
		ans += join(h[i].x , sa[h[i].id - 1] , sa[h[i].id]) ;
	}	  
	return ans ;
}
int main()
{
	ll ans = 0 ;
	ans -= solve1() ;
	int len1 = strlen(s) ;
	for(int i = 0 ; i < len1 ; i ++)  t[i] = s[i] ;
	ans -= solve1() ;
	int len2 = strlen(s) ;
	s[len2] = '#' ;
	for(int i = len2 + 1 ; i <= len2 + len1 ; i ++)  
	  s[i] = t[i - len2 - 1] ;
	s[len1 + len2 + 1] = 0 ;
	ans += solve2() ;
	printf("%lld\n" , ans) ;
	return 0 ;
}

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章