快速沃尔什变换(FWT)学习笔记

Part 1 简介

FFT 本质上是处理加法卷积,即 AiBjA_iB_j 贡献到 Ci+jC_{i+j}。而 FWT 则是处理位运算卷积,即 AiBjA_iB_j 贡献到 CijC_{i\star j},其中 \star 是某种位运算。

FWT 的思想与 FFT 相近,也是创造一种数列到数列的线性变换 FWT(A)\operatorname{FWT}(A) 满足它与它的逆变换都可以快速计算且 FWT(AB)=FWT(A)FWT(B)\operatorname{FWT}(A\star B)=\operatorname{FWT}(A)\cdot\operatorname{FWT}(B),其中 \cdot 表示两个数列对应位置相乘。

有了这样的线性变换就可以先求出 FWT(A)\operatorname{FWT}(A)FWT(A)\operatorname{FWT}(A),算出 FWT(AB)=FWT(A)FWT(B)\operatorname{FWT}(A\star B)=\operatorname{FWT}(A)\cdot\operatorname{FWT}(B),最后求出 C=IFWT(AB)C=\operatorname{IFWT}(A\star B)

本文中假定 A=B=n|A|=|B|=nnn22 的整次幂。

Part 2 原理

先来解决 FWT(A)\operatorname{FWT}(A) 如何求解,不妨设 FWT(A)i=j=0n1ci,jAj\operatorname{FWT}(A)_i=\sum\limits_{j=0}^{n-1}c_{i,j}A_j,然后探讨 cc 需要满足的条件。那么有:

(FWT(A)FWT(B))i=j=0n1ci,jAjk=0n1ci,kBk=j=0n1k=0n1ci,jci,kAjBkFWT(AB)i=FWT(C)=j=0n1ci,jCj=j=0n1ci,j1k,ln,kl=jAkBl=j=0n1k=0n1ci,jkAjBk\begin{aligned} (\text{FWT}(A)\cdot \text{FWT}(B))_i&=\sum\limits_{j=0}^{n-1} c_{i,j}A_j\sum\limits_{k=0}^{n-1}c_{i,k}B_k\\ &=\sum\limits_{j=0}^{n-1}\sum\limits_{k=0}^{n-1} c_{i,j}c_{i,k}A_jB_k\\ \text{FWT}(A\star B)_i=\text{FWT}(C)&=\sum\limits_{j=0}^{n-1} c_{i,j}C_j\\ &=\sum\limits_{j=0}^{n-1}c_{i,j}\sum\limits_{1\le k,l\le n,k\star l=j} A_kB_l\\ &=\sum\limits_{j=0}^{n-1}\sum\limits_{k=0}^{n-1} c_{i,j\star k}A_jB_k \end{aligned}

所以有:

ci,jci,k=ci,jkc_{i,j}c_{i,k}=c_{i,j\star k}

注意到由于 \star 是位运算,所以不妨钦定 cc 也可以拆位处理。令 (a)2=a0a1a2(a)_2=\overline{a_0a_1a_2\dots}aa 的二进制表示,不妨钦定 cc 满足 ci,j=ci0,j0ci1,j1ci2,j2c_{i,j}=c_{i_0,j_0}c_{i_1,j_1}c_{i_2,j_2}\dots,这样做的好处是只要知道 c0/1,0/1c_{0/1,0/1} 就可以求出 ci,jc_{i,j},而且也有:

cil,jlcil,kl=cil,jlklci,jci,k=ci,jkc_{i_l,j_l}c_{i_l,k_l}=c_{i_l,j_l\star k_l}\Leftrightarrow c_{i,j}c_{i,k}=c_{i,j\star k}

但是暴力求 FWT(A)\text{FWT}(A)O(n2)O(n^2) 的,考虑优化,令 aa'aa 去掉二进制最高位的数,按位折半即有:

FWT(A)i=j=0n21ci0,j0ci,jAj+j=n2n1ci0,j0ci,jAj=ci0,0j=0n21ci,jAj+ci0,1j=n2n1ci,jAj\begin{aligned} \text{FWT}(A)_i&=\sum\limits_{j=0}^{\frac{n}{2}-1} c_{i_0,j_0}c_{i',j'}A_{j}+\sum\limits_{j=\frac{n}{2}}^{n-1} c_{i_0,j_0}c_{i',j'}A_{j}\\ &=c_{i_0,0}\sum\limits_{j=0}^{\frac{n}{2}-1} c_{i',j'}A_{j}+c_{i_0,1}\sum\limits_{j=\frac{n}{2}}^{n-1} c_{i',j'}A_{j} \end{aligned}

那么考虑 i0i_0 的取值,有:

FWT(A)i={c0,0j=0n21ci,jAj+c0,1j=n2n1ci,jAj0in21c1,0j=0n21ci,jAj+c1,1j=n2n1ci,jAjn2in1\text{FWT}(A)_i=\begin{cases} c_{0,0}\sum\limits_{j=0}^{\frac{n}{2}-1} c_{i',j'}A_{j}+c_{0,1}\sum\limits_{j=\frac{n}{2}}^{n-1} c_{i',j'}A_{j}&0\le i\le \frac{n}{2}-1\\ c_{1,0}\sum\limits_{j=0}^{\frac{n}{2}-1} c_{i',j'}A_{j}+c_{1,1}\sum\limits_{j=\frac{n}{2}}^{n-1} c_{i',j'}A_{j}&\frac{n}{2}\le i\le n-1 \end{cases}

A0A0AA 中下标二进制最高位为 00 的部分,A1A1 为最高位为 11的部分,那么有:

FWT(A)i={c0,0FWT(A0)i+c0,1FWT(A1)i0in21c1,0FWT(A0)in2+c1,1FWT(A1)in2n2in1\text{FWT}(A)_i=\begin{cases} c_{0,0}\text{FWT}(A0)_i+c_{0,1}\text{FWT}(A1)_i&0\le i\le \frac{n}{2}-1\\ c_{1,0}\text{FWT}(A0)_{i-\frac{n}{2}}+c_{1,1}\text{FWT}(A1)_{i-\frac{n}{2}}&\frac{n}{2}\le i\le n-1 \end{cases}

假设 n=2mn=2^m,则可以在 O(m2m)O(m2^m)O(nlogn)O(n\log n) 的时间复杂度内求解 FWT(A)\text{FWT(A)}

对于 IFWT(A)\text{IFWT}(A),只需要构造出 c0/1,0/1c_{0/1,0/1} 的逆即可。

Part 3 具体实现

根据 ci,jci,k=ci,jkc_{i,j}c_{i,k}=c_{i,j\star k}ci0,j0ci1,j1=c2i0+i1,2j0+j1c_{i_0,j_0}c_{i_1,j_1}=c_{2i_0+i_1,2j_0+j_1} 构造 c0/1,0/1c_{0/1,0/1} 即可,称其为位矩阵。

构造过程比较人类智慧,注意矩阵必须要有逆,即每一行和每一列都有至少一个位置不为 00 且不能有两行或者两列完全一样,否则就会有维度被丢失(线性代数说法)。

由于不同的位运算的 FWT\text{FWT} 本质相同,只是 cc 不同,所以不妨设 FWT(A,[c0,0c0,1c1,0c1,1])\text{FWT}\left(A,\begin{bmatrix}c_{0,0}&c_{0,1}\\c_{1,0}&c_{1,1}\end{bmatrix}\right)AA 在对应的 cc 意义下的 FWT\text{FWT} 结果,那么有 FWT(FWT(A,c),c1)=A\text{FWT}\left(\text{FWT}\left(A,c\right),c^{-1}\right)=A

3.1 OR\text{OR} 卷积

考虑构造满足 ci,jci,k=ci,jkc_{i,j}c_{i,k}=c_{i,j|k} 且存在逆的位矩阵。

c0,0c0,0=c0,00=c0,0c0,0{0,1}c_{0,0}c_{0,0}=c_{0,0|0}=c_{0,0}\Rightarrow c_{0,0}\in\{0,1\}

同理,c0/1,0/1{0,1}c_{0/1,0/1}\in\{0,1\}

由于 c0,0c0,1=c0,1c_{0,0}c_{0,1}=c_{0,1},所以 c0,0=1,c0,1=0c_{0,0}=1,c_{0,1}=0 或者 c0,0=1,c0,1=1c_{0,0}=1,c_{0,1}=1

同理,c1,0=1,c1,1=0c_{1,0}=1,c_{1,1}=0 或者 c1,0=1,c1,1=1c_{1,0}=1,c_{1,1}=1

那么位矩阵就有两种构造方式:

[1011][1110]\begin{bmatrix} 1&0\\1&1 \end{bmatrix} \begin{bmatrix} 1&1\\1&0 \end{bmatrix}

Tips:

观察这个位矩阵:

[1011]\begin{bmatrix} 1&0\\1&1 \end{bmatrix}

注意到它满足 ci,j=[i&j=j]c_{i,j}=[i\&j=j],也就是说这种情况下 FWT(A,c)\text{FWT}(A,c) 实际上相当于子集求和。

这启发我们形如 Bi=ij=iAjB_i=\sum\limits_{i\star j=i}A_jBi=ij=jAjB_i=\sum\limits_{i\star j=j}A_j 这样的和式(\star 是某种位运算)也可以用 FWT\text{FWT} 来快速求。

由于第一个位矩阵满足 ci,j=[i&j=j]c_{i,j}=[i\&j=j],所以下面采用第一个位矩阵,则设 c1=[xyzw]c^{-1}=\begin{bmatrix}x&y\\z&w\end{bmatrix},则有:

{x+0z=1y+0w=0x+z=0y+w=1\begin{cases} x+0z=1\\ y+0w=0\\ x+z=0\\ y+w=1 \end{cases}

解得:

{x=1y=0z=1w=1\begin{cases} x=1\\ y=0\\ z=-1\\ w=1 \end{cases}

所以 c1=[1011]c^{-1}=\begin{bmatrix}1&0\\-1&1\end{bmatrix}

3.2 AND\text{AND} 卷积

c0,0c0,0=c0,0&0=c0,0c0,0{0,1}c_{0,0}c_{0,0}=c_{0,0\&0}=c_{0,0}\Rightarrow c_{0,0}\in\{0,1\}

同理,c0/1,0/1{0,1}c_{0/1,0/1}\in\{0,1\}

由于 c0,0c0,1=c0,0c_{0,0}c_{0,1}=c_{0,0},所以 c0,0=0,c0,1=1c_{0,0}=0,c_{0,1}=1c0,0=1,c0,1=1c_{0,0}=1,c_{0,1}=1

同理,c1,0=0,c1,1=1c_{1,0}=0,c_{1,1}=1c1,0=1,c1,1=1c_{1,0}=1,c_{1,1}=1

那么位矩阵就有两种构造方式:

[0111][1101]\begin{bmatrix} 0&1\\1&1 \end{bmatrix} \begin{bmatrix} 1&1\\0&1 \end{bmatrix}

由于第一个位矩阵满足 ci,j=[ij=2k1]c_{i,j}=[i|j=2^k-1],所以采用第一个位矩阵,同理,待定系数法求逆得 c1=[1110]c^{-1}=\begin{bmatrix}-1&1\\1&0\end{bmatrix}

3.3 XOR\text{XOR} 卷积

由于对于任意的 x,yx,y,均有 c0,0cx,y=cx,yc_{0,0}c_{x,y}=c_{x,y},所以 c0,0=1c_{0,0}=1

根据 c1,1c1,1=c1,0c_{1,1}c_{1,1}=c_{1,0} 且矩阵不存在为 00 的行,所以 c1,0c_{1,0}c1,1c_{1,1} 均非 00​。

根据 c1,0c1,0=c1,0c_{1,0}c_{1,0}=c_{1,0}c1,0=0c_{1,0}\not=0 可得 c1,0=1c_{1,0}=1

根据,c0,1c0,1=c1,0c_{0,1}c_{0,1}=c_{1,0},可得 c0,1=1c_{0,1}=-1c0,1=1c_{0,1}=1

同理,c1,1c1,1=c1,0c_{1,1}c_{1,1}=c_{1,0}c1,1=1c_{1,1}=-1c1,1=1c_{1,1}=1

那么位矩阵就有两种构造方式:

[1111][1111]\begin{bmatrix} 1&-1\\1&1 \end{bmatrix} \begin{bmatrix} 1&1\\1&-1 \end{bmatrix}

同样的,由于第二个位矩阵满足 ci,j=(1)i&jc_{i,j}=(-1)^{|i\&j|}a|a|aa 二进制表示中 11 的个数),所以采用第二个位矩阵,求逆得 c1=[12121212]c^{-1}=\begin{bmatrix}\frac{1}{2}&\frac{1}{2}\\\frac{1}{2}&-\frac{1}{2}\end{bmatrix}

3.4 代码实现

直接套

FWT(A)i={c0,0FWT(A0)i+c0,1FWT(A1)i0in21c1,0FWT(A0)in2+c1,1FWT(A1)in2n2in1\text{FWT}(A)_i=\begin{cases} c_{0,0}\text{FWT}(A0)_i+c_{0,1}\text{FWT}(A1)_i&0\le i\le \frac{n}{2}-1\\ c_{1,0}\text{FWT}(A0)_{i-\frac{n}{2}}+c_{1,1}\text{FWT}(A1)_{i-\frac{n}{2}}&\frac{n}{2}\le i\le n-1 \end{cases}

即可,P4717 【模板】快速莫比乌斯/沃尔什变换 (FMT/FWT) 代码如下:

#include <iostream>
#include <cstdio>

using namespace std;

const int p=998244353,inv2=499122177;

inline int getlen(int n)
{
	int res=1;
	while(res<n) res<<=1;
	return res;
}

const int ORC[2][2]={{1,0},{1,1}},IORC[2][2]={{1,0},{p-1,1}};
const int ANDC[2][2]={{0,1},{1,1}},IANDC[2][2]={{p-1,1},{1,0}};
const int XORC[2][2]={{1,1},{1,p-1}},IXORC[2][2]={{inv2,inv2},{inv2,p-inv2}};

inline void FWT(int n,int a[],const int c[2][2])
{
	for(int len=2;len<=n;len<<=1)
	{
		int mid=len>>1;
		for(int l=0;l<=n-len;l+=len)
		{
			for(int k=0;k<mid;k++)
			{
				int x=a[l+k],y=a[l+mid+k];
				a[l+k]=(1ll*c[0][0]*x%p+1ll*c[0][1]*y%p)%p;
				a[l+mid+k]=(1ll*c[1][0]*x%p+1ll*c[1][1]*y%p)%p;
			}
		}
	}
}

int n;
int a[1<<17],b[1<<17],c[1<<17];

int main()
{
	scanf("%d",&n);
	n=1<<n;
	for(int i=0;i<n;i++) scanf("%d",&a[i]);
	for(int i=0;i<n;i++) scanf("%d",&b[i]);
	FWT(n,a,ORC),FWT(n,b,ORC);
	for(int i=0;i<n;i++) c[i]=1ll*a[i]*b[i]%p;
	FWT(n,c,IORC);
	for(int i=0;i<n;i++) printf("%d ",c[i]);
	printf("\n");
	FWT(n,a,IORC),FWT(n,b,IORC);
	
	FWT(n,a,ANDC),FWT(n,b,ANDC);
	for(int i=0;i<n;i++) c[i]=1ll*a[i]*b[i]%p;
	FWT(n,c,IANDC);
	for(int i=0;i<n;i++) printf("%d ",c[i]);
	printf("\n");
	FWT(n,a,IANDC),FWT(n,b,IANDC);
	
	FWT(n,a,XORC),FWT(n,b,XORC);
	for(int i=0;i<n;i++) c[i]=1ll*a[i]*b[i]%p;
	FWT(n,c,IXORC);
	for(int i=0;i<n;i++) printf("%d ",c[i]);
	printf("\n");
	return 0;
}

Part 4 更多拓展

有些时候考的往往不是裸的 FWT\text{FWT}

下文中若 AABB 为数列,\star 为某种位运算,那么 ABA\star B 表示 AABB\star 运算下的卷积结果,即 (AB)i=jk=iAjBk(A\star B)_i=\sum\limits_{j\star k=i}A_jB_k

FWT\text{FWT} 应用时往往要利用它是线性变换来优化,即 FWT(A)+FWT(B)=FWT(A+B)\text{FWT}(A)+\text{FWT}(B)=\text{FWT(A+B)}FWT(aA)=aFWT(A)\text{FWT}(aA)=a\text{FWT}(A)

AA 只有少数项非 00 则可能有分类讨论优化时间复杂度的做法。

一些例题:

4.1 离线子集卷积

Ck=ij=k,i&j=0AiBj=ikAiBkiC_{k}=\sum\limits_{i|j=k,i\&j=0}A_iB_j=\sum\limits_{i\subseteq k}A_iB_{k-i}

发现 i&j=0i\&j=0 很烦,但是不难发现它等价于 i+j=k|i|+|j|=|k|a|a| 表示 aa 二进制表示中的 11 的个数),所以可以令 SAi,j=[j=i]Aj,SBi,j=[j=i]BjSA_{i,j}=[|j|=i]A_j,SB_{i,j}=[|j|=i]B_j,那么有:

Ri=j=0iSAjSBijR_i=\sum\limits_{j=0}^iSA_j|SB_{i-j}

由于 FWT\text{FWT} 是线性变换,所以有:

Ri=IFWT(j=0iFWT(SAj)FWT(SBij))R_i=\text{IFWT}\left(\sum\limits_{j=0}^i\text{FWT}(SA_j)\cdot \text{FWT}(SB_{i-j})\right)

答案即为 Ri,iR_{|i|,i},时间复杂度 O(m22m)O(m^22^m),参考代码:(P6097 【模板】子集卷积

int main()
{
	scanf("%d",&n);
	n=1<<n;
	for(int i=0;i<n;i++) scanf("%d",&a[i]);
	for(int i=0;i<n;i++) scanf("%d",&b[i]);
	for(int i=0;i<(1<<20);i++) for(int j=0;j<20;j++) popc[i]+=i>>j&1;
	for(int i=0;i<=20;i++) for(int j=0;j<n;j++) sa[i][j]=(popc[j]==i)*a[j],sb[i][j]=(popc[j]==i)*b[j];
	for(int i=0;i<=20;i++) FWT(n,sa[i],ORC),FWT(n,sb[i],ORC);
	for(int i=0;i<=20;i++) for(int j=0;j<=i;j++) for(int k=0;k<(1<<20);k++) r[i][k]=(r[i][k]+1ll*sa[j][k]*sb[i-j][k]%p)%p;
	for(int i=0;i<=20;i++) FWT(n,r[i],IORC);
	for(int i=0;i<n;i++) printf("%d ",r[popc[i]][i]);
	printf("\n");
	return 0;
}

4.2 半在线子集卷积

Ck=Bkij=k,i&j=0,i=kCiAj=BkikCiAkiC_k=B_k\sum\limits_{i|j=k,i\&j=0,i\not=k}C_iA_j=B_k\sum\limits_{i\subset k}C_iA_{k-i}

和离线子集卷积类似,令 SAi,j=[j=i]Aj,SCi,j=[j=i]CjSA_{i,j}=[|j|=i]A_j,SC_{i,j}=[|j|=i]C_j,那么有:

SCi=Bk=0i1SCkSAik=BIFWT(j=0iFWT(SCj)FWT(SAij))\begin{aligned} SC_{i}&=B\cdot\sum\limits_{k=0}^{i-1}SC_k|SA_{i-k}\\ &=B\cdot\text{IFWT}\left(\sum\limits_{j=0}^i\text{FWT}(SC_j)\cdot \text{FWT}(SA_{i-j})\right) \end{aligned}

那么从小到大枚举 ii 计算即可。

4.3 每一位运算法则不同

给定一个长 logn\log n 的字符串,字符集为 |&^,表示每一位要进行的位运算。

依旧是考虑:

FWT(A)i={c0,0FWT(A0)i+c0,1FWT(A1)i0in21c1,0FWT(A0)in2+c1,1FWT(A1)in2n2in1\text{FWT}(A)_i=\begin{cases} c_{0,0}\text{FWT}(A0)_i+c_{0,1}\text{FWT}(A1)_i&0\le i\le \frac{n}{2}-1\\ c_{1,0}\text{FWT}(A0)_{i-\frac{n}{2}}+c_{1,1}\text{FWT}(A1)_{i-\frac{n}{2}}&\frac{n}{2}\le i\le n-1 \end{cases}

只不过此时 cc 取字符串第 logn\log n 位对应运算的那个矩阵。

inline void FWT(int n,int a[],int w)
{
	for(int len=2,pos=1;len<=n;len<<=1,pos++)
	{
		int mid=len>>1;
		int c[2][2];
		if(w==1)
		{
			if(str[pos]=='|') memcpy(c,ORC,sizeof(ORC));
			if(str[pos]=='&') memcpy(c,ANDC,sizeof(ANDC));
			if(str[pos]=='^') memcpy(c,XORC,sizeof(XORC));
		}
		else
		{
			if(str[pos]=='|') memcpy(c,IORC,sizeof(IORC));
			if(str[pos]=='&') memcpy(c,IANDC,sizeof(IANDC));
			if(str[pos]=='^') memcpy(c,IXORC,sizeof(IXORC));
		}
		for(int l=0;l<=n-len;l+=len)
		{
			for(int k=0;k<mid;k++)
			{
				int x=a[l+k],y=a[l+mid+k];
				a[l+k]=(1ll*c[0][0]*x%p+1ll*c[0][1]*y%p)%p;
				a[l+mid+k]=(1ll*c[1][0]*x%p+1ll*c[1][1]*y%p)%p;
			}
		}
	}
}