快速数论变换(NTT)学习笔记

FFT 虽然奇妙,但是由于需要用浮点数,所以有精度问题。NTT 就是取模时 FFT 的一个完美替代品。

FFT 有精度问题的原因显然是涉及了本原单位根,它要用三角函数求,那么我们可以考虑找一个在模素数时的“本原单位根”

原根

众所周知,在模 pp 意义下只有 p1p-1 个能代入多项式的值(00 没用),所以我们只要找到一个能“遍历”这 p1p-1 个值的数,就能构成一个“单位圆”,也就能找到“本原单位根”了。

这个能构成“单位圆”的数就被称为原根 gg

形式化的,gg 为模 pp 的原根,当且仅当在模 pp 意义下 g0=g1=g2=...=gp2g^0\not=g^1\not=g^2\not=...\not=g^{p-2}

也就是说在模世界的幂运算这张纸上,原根 gg 是一个圆。

容易发现 gg 就是在模 pp 意义下的 p1p-1 次本原单位根。

可以证明,在模素数 pp 的意义下,gg 总是存在的。所以 NTT 的模数必须要是素数

我们考虑模 pp 意义下的 nn 次本原单位根 ωn\omega_n由于 gg 的次幂构成了“可以被分为 p1p-1 等分的单位圆”,所以 ωn=gp1n\omega_n=g^{\frac{p-1}{n}}

显然,pp 意义下的 ωn\omega_n 存在当且仅当 np1n\mid p-1

NTT

如果 n=2kn=2^k,那么我们可以把 FFT 中的虚数本原单位根都替换成模素数 pp 意义下的本原单位根,来实现 NTT。

但因为必须满足 2k=np12^k=n\mid p-1,所以 NTT 对模数有特殊限制

NTT 可用的模数 pp 需要满足:

  • pp 是个质数

  • p=a2b+1p=a\cdot2^b+1

例如 998244353998244353 就是合法的,因为 998244353=717223+1998244353=7\cdot17\cdot2^{23}+1。它的一个原根是 g=3g=3

还有 167772161167772161104857601104857601 也是合法的,g=3g=3 都是它们的原根。

模板题代码如下:

#include <iostream>
#include <cstdio>

using namespace std;

const long long MS=5000005;
const int p=998244353,ginv=332748118;

inline int qpow(int x,int y)
{
    int res=1;
    for(;y>0;y>>=1) res=((y&1)?1ll*res*x%p:res),x=1ll*x*x%p;
    return res;
}
inline int getlen(int n)
{
    int res=1;
    while(res<n) res<<=1;
    return res;
}
int p_rev[MS],p_rev_lstn;
inline void NTT(int n,int a[],int tpe)
{
    if(p_rev_lstn!=n)
    {
        p_rev_lstn=n;
        for(int i=0;i<n;i++) p_rev[i]=(p_rev[i>>1]>>1)|((i&1)?n>>1:0);
    }
    for(int i=0;i<n;i++) if(p_rev[i]<i) swap(a[p_rev[i]],a[i]);
    int g=tpe==1?3:ginv;
    for(int mid=1;mid<n;mid<<=1)
    {
        int len=mid<<1,Wn=qpow(g,(p-1)/len);
        for(int l=0;l<n-len+1;l+=len)
        {
            for(int k=0,Wk=1;k<mid;k++,Wk=1ll*Wk*Wn%p)
            {
                int x=a[l+k],y=1ll*Wk*a[l+mid+k]%p;
                a[l+k]=(x+y)%p,a[l+mid+k]=(x-y+p)%p;
            }
        }
    }
}
inline void DFT(int n,int a[]){NTT(n,a,1);}
inline void IDFT(int n,int a[])
{
    NTT(n,a,-1);
    int inv=qpow(n,p-2);
    for(int i=0;i<n;i++) a[i]=1ll*a[i]*inv%p;
}

int n,m,a[MS],b[MS];

int main()
{
	scanf("%d%d",&n,&m);
	for(int i=0;i<=n;i++)
	{
		scanf("%d",&a[i]);
	}
	for(int i=0;i<=m;i++)
	{
		scanf("%d",&b[i]);
	}
	int len=getlen(n+m+1);
	DFT(len,a);
	DFT(len,b);
	for(int i=0;i<len;i++)
	{
		a[i]=1ll*a[i]*b[i]%p;
	}
	IDFT(len,a);
	for(int i=0;i<n+m+1;i++)
	{
		printf("%d ",a[i]);
	}
	printf("\n");
	return 0;
}