FWT的簡介
一般FWT用來解決一下問題:
- Ck=∑i∣j=kAiBj
- Ck=∑i&j=kAiBj
- Ck=∑i⊕j=kAiBj
實現的大概思路就是就是先把他們轉化成fwt(A)(類似FFT的點值表達),然後對應爲相乘,最後在還原爲多項式(整個過程很類似與快速傅里葉變換)
or 卷積
現在要做到這個的快速卷積:Ck=∑i∣j=kAiBj
定義A∣B爲多項式的or卷積,顯然 A∣B=B∣A(交換律),(A+B)∣C=A∣C+B∣C(結合律)
定義fwt(A)[k]=∑i∣kAi,因爲我們要讓fwt(C)=fwt(A)×fwt(B),這個基於的原理就是若i∣k=k,j∣k=k,就能推出(i∣j)∣k=k,那麼原來是k子集的或之後還是k子集。
現在我們來研究如何計算fwt(A),類似FFT地用分治的方法,設A0爲當前位爲0的項,A1就是剩下部分:
fwt(A)={(fwt(A0),fwt(A0+A1))An>0n=0根據fwt(A)的定義可以知道:fwt(A+B)=fwt(A)+fwt(B),根據定義也可以知道如果最高位爲1,那麼就把fwt(A0+A1)算作A的後半部分就可以了,爲0的話子集就是0
現在我們來證明一下fwt(A∣B)=fwt(A)×fwt(B):
fwt(A∣B)=fwt((A∣B)0,(A∣B)1)=fwt(A0∣B0,A0∣B1+A1∣B0+A1∣B1)=(fwt(A0∣B0),fwt(A0∣B0+A1∣B0+A0∣B1+A1B1))=(fwt(A0)×fwt(B0),fwt(A0+A1)×fwt(B0+B1))=(fwt(A0),fwt(A0+A1))×(fwt(B0),fwt(B0,B1))=fwt(A)×fwt(B)這裏用到了數學歸納法,首先n=0的情況肯定成立,然後我們假設較小的規模成立,以此推導更大的規模(這裏是21)
最後返回來的dfwt就是根據上面的fwt設計的,變換如下:
dfwt(A)={dfwt(A0),dfwt(A1−A0)An>1n=0
and 卷積
這個和上面的卷積極其類似,直接給出結論:
fwt(A)={(fwt(A0+A1),fwt(A0))An>0n=0逆變換如下:
dfwt(A)={dfwt(A0−A1),dfwt(A1)An>0n=0
xor 卷積
這就是重頭戲了,我們要解決:Ck=∑i⊕j=kAiBj
先定義fwt(A)[x]=∑2∣d(x∩i)Ai−∑2∣[d(x∩i)−1]Ai,d是二進制1的個數,這樣定義是基於一個結論:
d(x∩(i⊕j))=d(x∩i)+d(x∩j)−2d(x∩i∩j)考慮每一位的合法性,就可以推知所有情況,這個自己枚舉一下可能的組合就行了,然後:
fwt(A)[x]⊕fwt(B)[x]=2∣d(x∩i)∑2∣d(x∩j)∑AiBj−2∣[d(x∩i)−1]∑2∣d(x∩j)∑AiBj......你可以發現如果d(x∩i)和d(x∩j)奇偶性相同的話那麼前面的符號是正,否則前面的符號是負,我們觀察上面的結論,發現d(x∩(i⊕j))和d(x∩i)+d(x∩j)的奇偶性相同,恰好就對應了。
現在來考慮正變換,基本思路還是考慮最高位,左半邊的話都選自己和選自己再選對半邊都可以,而且不需要變號,因爲不會產生新的1位,兩個都選右半邊的話就需要變號,那麼就這樣算:
fwt(A)={(fwt(A0+A1),fwt(A0−A1))An>0n=0然後逆變化就是:ifwt(A)={(2ifwt(A0+A1),2ifwt(A0−A1))An>0n=0
然後貼一個板題的代碼:
#include <cstdio>
const int M = 200005;
const int MOD = 998244353;
int read()
{
int num=0,flag=1;char c;
while((c=getchar())<'0'||c>'9')if(c=='-')flag=-1;
while(c>='0'&&c<='9')num=(num<<3)+(num<<1)+(c^48),c=getchar();
return num*flag;
}
int n,a[M],b[M],A[M],B[M],inv2=(MOD+1)/2;
void fwt_or(int *a,int op)
{
for(int i=1;i<n;i<<=1)
for(int p=i<<1,j=0;j<n;j+=p)
for(int k=0;k<i;k++)
{
if(op==1) a[i+j+k]=(a[i+j+k]+a[j+k])%MOD;
else a[i+j+k]=(a[i+j+k]-a[j+k]+MOD)%MOD;
}
}
void fwt_and(int *a,int op)
{
for(int i=1;i<n;i<<=1)
for(int p=i<<1,j=0;j<n;j+=p)
for(int k=0;k<i;k++)
{
if(op==1) a[j+k]=(a[j+k]+a[i+j+k])%MOD;
else a[j+k]=(a[j+k]-a[i+j+k]+MOD)%MOD;
}
}
void fwt_xor(int *a,int op)
{
for(int i=1;i<n;i<<=1)
for(int p=i<<1,j=0;j<n;j+=p)
for(int k=0;k<i;k++)
{
int x=a[j+k],y=a[i+j+k];
a[j+k]=(x+y)%MOD;
a[i+j+k]=(x+MOD-y)%MOD;
if(op==-1)
{
a[j+k]=1ll*a[j+k]*inv2%MOD;
a[i+j+k]=1ll*a[i+j+k]*inv2%MOD;
}
}
}
void init()
{
for(int i=0;i<n;i++)
A[i]=a[i],B[i]=b[i];
}
signed main()
{
n=1<<read();
for(int i=0;i<n;i++) a[i]=read();
for(int i=0;i<n;i++) b[i]=read();
init();
fwt_or(A,1);fwt_or(B,1);
for(int i=0;i<n;i++) A[i]=1ll*A[i]*B[i]%MOD;
fwt_or(A,-1);
for(int i=0;i<n;i++) printf("%d ",A[i]);
puts("");
init();
fwt_and(A,1);fwt_and(B,1);
for(int i=0;i<n;i++) A[i]=1ll*A[i]*B[i]%MOD;
fwt_and(A,-1);
for(int i=0;i<n;i++) printf("%d ",A[i]);
puts("");
init();
fwt_xor(A,1);fwt_xor(B,1);
for(int i=0;i<n;i++) A[i]=1ll*A[i]*B[i]%MOD;
fwt_xor(A,-1);
for(int i=0;i<n;i++) printf("%d ",A[i]);
puts("");
}