任意模数快速傅里叶变换(MTT)学习笔记

NTT 支持取模,没有精度问题,但是有模数限制。

FFT 支持取模,模数没有限制,但是有精度限制。

为了实现任意取模且几乎没有精度限制的多项式乘法,人们发明了 MTT。

模板题

基于 NTT 的实现

显然结果的实际值不超过 109×109×105=102310^9\times 10^9\times 10^5=10^{23},那么我们可以求出结果模三个满足 ABC1023ABC\ge 10^{23} 的质数 AABBCC 的值,然后用中国剩余定理合并答案

但是这种方法需要 99 次 NTT,常数极其巨大

基于 FFT 的实现

既然 FFT 有精度限制,那么我们就要想办法降低运算所需要的精度

不难想到,可以k=pk=\sqrt pF(x)=kA1(x)+B1(x)F(x)=kA_1(x)+B_1(x)G(x)=kA2(x)+B2(x)G(x)=kA_2(x)+B_2(x),那么有

FG=(kA1+B1)(kA2+B2)F*G=(kA_1+B_1)(kA_2+B_2)

=k2A1A2+k(A1B2+A2B1)+B1B2=k^2A_1A_2+k(A_1B_2+A_2B_1)+B_1B_2

DFT(A1),DFT(A2),DFT(B1),DFT(B2)\operatorname{DFT}(A_1),\operatorname{DFT}(A_2),\operatorname{DFT}(B_1),\operatorname{DFT}(B_2)IDFT(A1A2),IDFT(A1B2+A2B1),DFT(B1B2)\operatorname{IDFT}(A_1A_2),\operatorname{IDFT}(A_1B_2+A_2B_1),\operatorname{DFT}(B_1B_2) 即可,共需 77 次 FFT。

可以再快点吗?

当然可以!但是太过复杂不便于记忆,而且优化效果貌似不够明显,所以以后再补坑吧。

代码如下:

#include <iostream>
#include <cstdio>
#include <cmath>

using namespace std;

const long long MS=500005;
const long double PI=acos(-1);

struct plex
{
    long double x,y;
    plex(long double a=0,long double b=0) {x=a,y=b;}
};
plex operator+(plex a,plex b) {return plex(a.x+b.x,a.y+b.y);}
plex operator-(plex a,plex b) {return plex(a.x-b.x,a.y-b.y);}
plex operator*(plex a,plex b) {return plex(a.x*b.x-a.y*b.y,a.x*b.y+b.x*a.y);}

inline int getlen(int n)
{
    int res=1;
    while(res<n) res<<=1;
    return res;
}
int p_rev[MS],p_rev_lstn;
inline void FFT(int n,plex a[],int tpe)
{
    if(p_rev_lstn!=n)
    {
        p_rev_lstn=n;
        for(int i=0;i<n;i++) p_rev[i]=(p_rev[i>>1]>>1)|((i&1)?n>>1:0);
    }
    for(int i=0;i<n;i++) if(p_rev[i]<i) swap(a[p_rev[i]],a[i]);
    for(int mid=1;mid<n;mid<<=1)
    {
        int len=mid<<1;
        plex Wn=plex(cos(2*PI/len),tpe*sin(2*PI/len));
        for(int l=0;l<n-len+1;l+=len)
        {
            plex Wk=plex(1,0);
            for(int k=0;k<mid;k++,Wk=Wk*Wn)
            {
                plex x=a[l+k],y=Wk*a[l+mid+k];
                a[l+k]=x+y,a[l+mid+k]=x-y;
            }
        }
    }
}
inline void DFT(int n,plex a[]) {FFT(n,a,1);}
inline void IDFT(int n,plex a[])
{
    FFT(n,a,-1);
    for(int i=0;i<n;i++) a[i].x/=n,a[i].y/=n;
}
plex A1[MS],B1[MS],A2[MS],B2[MS];
plex p_mul_tmp1[MS],p_mul_tmp2[MS],p_mul_tmp3[MS];
inline void PMUL(int n,int m,int resn,int F[],int G[],int res[],int p)
{
	int len=getlen(n+m);
	int k=ceil(sqrt(p));
	for(int i=0;i<n;i++) A1[i].x=F[i]/k,B1[i].x=F[i]%k;
	for(int i=0;i<m;i++) A2[i].x=G[i]/k,B2[i].x=G[i]%k;
	DFT(len,A1),DFT(len,A2),DFT(len,B1),DFT(len,B2);
	for(int i=0;i<len;i++) p_mul_tmp1[i]=A1[i]*A2[i],p_mul_tmp2[i]=A1[i]*B2[i]+A2[i]*B1[i],p_mul_tmp3[i]=B1[i]*B2[i];
	IDFT(len,p_mul_tmp1),IDFT(len,p_mul_tmp2),IDFT(len,p_mul_tmp3);
	for(int i=0;i<resn;i++)
	{
		int x=(long long)(p_mul_tmp1[i].x+0.5)%p,y=(long long)(p_mul_tmp2[i].x+0.5)%p,z=(long long)(p_mul_tmp3[i].x+0.5)%p;
		res[i]=((1ll*k*k%p*x%p+1ll*k*y%p)%p+z)%p;
	}
	for(int i=0;i<len;i++) A1[i]=A2[i]=B1[i]=B2[i]=p_mul_tmp1[i]=p_mul_tmp2[i]=p_mul_tmp3[i]=plex();
}

int n,m,p;
int F[MS],G[MS],res[MS]; 

int main()
{
	scanf("%d%d%d",&n,&m,&p);
	n++;
	m++;
	for(int i=0;i<n;i++)
	{
		scanf("%d",&F[i]);
		F[i]=(F[i]%p+p)%p;
	}
	for(int i=0;i<m;i++)
	{
		scanf("%d",&G[i]);
		G[i]=(G[i]%p+p)%p;
	}
	PMUL(n,m,n+m-1,F,G,res,p);
	for(int i=0;i<n+m-1;i++)
	{
		printf("%d ",res[i]);
	}
	printf("\n");
	return 0;
}