多项式全家桶学习笔记

有了 NTT,就有了多项式全家桶……

首先要感谢 @command_block 的文章《NTT 与多项式全家桶》以及 @Epsilon_Cube、@MoYuFang 和 @Diu 给予我的许多帮助。

因为板子在不断修 bug,所以代码最后统一放。

前置芝士

  • 多项式的各种运算是怎么定义的

由于我们只知道多项式加法和多项式乘法,但是这已经够了。所以所有的多项式运算都是用多项式加法和乘法定义的

  • 次数界

很多时候我们只对多项式 ff 的前 nn 项感兴趣(这时往往 ff 会有无限项),所以需要在 (modxn)\pmod{x^n} 的意义下运算。

由于多项式加法和多项式乘法的结果只会从低次项向高次项贡献,所以有:

F(x)modxn+G(x)modxnF(x)+G(x)(modxn)F(x)modxnG(x)modxnF(x)G(x)(modxn)\begin{aligned} F(x)\operatorname{mod}{x^n}+G(x)\operatorname{mod}{x^n}&\equiv F(x)+G(x)&\pmod{x^n}\\ F(x)\operatorname{mod}{x^n}\cdot G(x)\operatorname{mod}{x^n}&\equiv F(x)G(x)&\pmod{x^n}\\ \end{aligned}

即我们可以在有次数界的情况下定义所有多项式运算

一些记号

  • [xi]F(x)[x^i]F(x):多项式 F(x)F(x)ii 次项的系数,xix^i 的系数
  • FR(x)F_R(x)nn 次多项式的翻转 FR(x)=xnF(1x)F_R(x)=x^nF(\frac{1}{x})显然 FR(x)F_R(x) 的系数是 F(x)F(x) 的系数的翻转
  • F(n)(x)F^{(n)}(x)多项式 F(x)F(x)nn 阶导数,即对 F(x)F(x) 求导 nn 次的结果;

多项式求导和积分

定义多项式的求导:

F(x)=i=0n1ai+1(i+1)xiF^\prime(x)=\sum\limits_{i=0}^{n-1} a_{i+1}(i+1)x^{i}

定义多项式的积分(不定积分):

F(x)dx=C+i=1nai1xii\int F(x)\,dx=C+\sum\limits_{i=1}^{n} \dfrac{a_{i-1}x^{i}}{i}

同样的,多项式求导和积分也是互为逆操作

多项式牛顿迭代

这是一个比较重要的知识,有了它,就可以无脑推多项式各种操作的递推式了。

形式:已知函数 GG 满足 G(F(x))=0G(F(x))=0,求 F(x)modxnF(x)\operatorname{mod} x^n

实践中 GG 一般较为手动构造的简单函数。

结论:F(x)F(x)G(F(x))G(F(x))(modxn)F(x)\equiv F_*(x)-\dfrac{G(F_*(x))}{G'(F_*(x))}\pmod{x^n},其中 F(x)F(x)(modxn2)F_*(x)\equiv F(x)\pmod{x^{\frac{n}{2}}},注意 [x0]G(F(x))=0[x^0]G(F(x))=0 的解要单独求出。

和一般的牛迭十分相似,但是次数每次翻倍。证明如下:

假设目前已经求出了 F(x)F_*(x),考虑 G(F(x))G(F(x))F(x)F_*(x) 处的泰勒展开:

i=0G(i)(F(x))i!(F(x)F(x))i0(modxn)\sum\limits_{i=0}^{\infin} \frac{G^{(i)}(F_*(x))}{i!}(F(x)-F_*(x))^i\equiv0\pmod{x^n}

注意到 F(x)F(x)F(x)-F_*(x) 的最低系数非 00 项至少是 xn2x^{\frac{n}{2}},那么对于所有 i2i\ge2ii 都有 (F(x)F(x))i0(modxn)(F(x)-F_*(x))^i\equiv 0\pmod{x^n},所以:

G(F(x))+G(F(x))(F(x)F(x))0(modxn)F(x)F(x)G(F(x))G(F(x))(modxn)\begin{aligned} G(F_*(x))+G'(F_*(x))(F(x)-F_*(x))&\equiv0&\pmod{x^n}\\ F(x)&\equiv F_*(x)-\frac{G(F_*(x))}{G'(F_*(x))}&\pmod{x^n}\\ \end{aligned}

证毕。

多项式乘法逆

P4238 【模板】多项式乘法逆

假设已经求出了 B(x)F(x)1(modxn2)B_*(x)F(x)\equiv 1\pmod{x^{\frac{n}{2}}},现在要求 B(x)F(x)1(modxn)B(x)F(x)\equiv1\pmod{x^n},那么有:

G(B(x))=1B(x)F(x)0(modxn)\begin{aligned} G(B(x))=\frac{1}{B(x)}-F(x)\equiv 0\pmod{x^n} \end{aligned}

则可以直接套牛顿迭代:

B(x)B(x)G(B(x))G(B(x))(modxn)B(x)B(x)1B(x)F(x)1B2(x)(modxn)B(x)B(x)+B(x)B2F(x)(modxn)B(x)2B(x)B2F(x)(modxn)\begin{aligned} B(x)&\equiv B_*(x)-\frac{G(B_*(x))}{G'(B_*(x))}&\pmod{x^n}\\ B(x)&\equiv B_*(x)-\frac{\frac{1}{B_*(x)}-F(x)}{-\frac{1}{B_*^2(x)}}&\pmod{x^n}\\ B(x)&\equiv B_*(x)+B_*(x)-B_*^2F(x)&\pmod{x^n}\\ B(x)&\equiv 2B_*(x)-B_*^2F(x)&\pmod{x^n}\\ \end{aligned}

那么就可以做了,[x0]B(x)[x^0]B(x) 需要求一次乘法逆元,时间复杂度为 T(n)=T(n2)+O(nlogn)=O(nlogn)T(n)=T(\frac{n}{2})+O(n\log n)=O(n\log n)

不过还有个优化,注意到 NTT 的过程代入的是单位根,所以求的实际上是循环卷积:

F(x)G(x)=k=0mxki+jmodm=kfigjF(x)G(x)=\sum\limits_{k=0}^mx^k\sum\limits_{i+j\mod m=k} f_ig_j

观察倍增式子:

B(x)2B(x)B2(x)F(x)(modxn)B(x)\equiv 2B_*(x)-B_*^2(x)F(x)\pmod{x^n}

需要用到乘法的只有 B2(x)F(x)B_*^2(x)F(x)

先计算 B(x)F(x)B_*(x)F(x),它们的次数分别是 len2\dfrac{len}{2}lenlen

由于结果的第一项为 11,这个 11 后面的 len21\dfrac{len}{2}-1 项都为 00,所以长度为 lenlen 的循环卷积只会破坏前面的 1100

最后乘上一个 B(x)B_*(x) 即可,此时循环卷积只会破坏前 len2\dfrac{len}{2} 项。

多项式开根

P5205 【模板】多项式开根

P5277 【模板】多项式开根(加强版)

假设已经求出了 B2(x)F(x)(modxn2)B_*^2(x)\equiv F(x)\pmod{x^{\frac{n}{2}}},现在要求 B(x)2F(x)(modxn)B(x)^2\equiv F(x)\pmod{x^n},那么有:

G(B(x))=B2(x)F(x)0(modxn)G(B(x))=B^2(x)-F(x)\equiv 0\pmod{x^n}

直接套牛迭:

B(x)B(x)G(B(x))G(B(x))(mod2n)B(x)B(x)B2(x)F(x)2B(x)(mod2n)B(x)2B(x)2B2(x)+F(x)2B(x)(mod2n)B(x)B2(x)+F(x)2B(x)(mod2n)\begin{aligned} B(x)&\equiv B_*(x)-\frac{G(B_*(x))}{G'(B_*(x))}&\pmod{2^n}\\ B(x)&\equiv B_*(x)-\frac{B_*^2(x)-F(x)}{2B_*(x)}&\pmod{2^n}\\ B(x)&\equiv \frac{2B_*(x)^2-B_*^2(x)+F(x)}{2B_*(x)}&\pmod{2^n}\\ B(x)&\equiv \frac{B_*^2(x)+F(x)}{2B_*(x)}&\pmod{2^n}\\ \end{aligned}

最后 [x0]B(x)[x^0]B(x) 需要求一次二次剩余,可以用 BSGS/exBSGS 求单位根的高次同余方程来求解,再加上一个求逆,一个乘法,一个加法就做完了。

时间复杂度为 T(n)=T(n2)+O(nlogn)=O(nlogn)T(n)=T(\frac{n}{2})+O(n\log n)=O(n\log n)

多项式 ln\ln

P4725 【模板】多项式对数函数(多项式 ln)

ln(A(x))B(x)(modxn)\ln(A(x))\equiv B(x)\pmod{x^n}

两边同时求导,得:

ln(A(x))A(x)B(x)(modxn)\ln'(A(x))A'(x)\equiv B'(x)\pmod{x^n}

注意到 ln(x)=1x\ln'(x)=\dfrac{1}{x},所以:

A(x)A(x)B(x)(modxn)\dfrac{A'(x)}{A(x)}\equiv B'(x)\pmod{x^n}

再积分回来:

B(x)A(x)A(x)dx(modxn)B(x)\equiv \int \dfrac{A'(x)}{A(x)}dx\pmod{x^n}

所以一个求导,一个逆元,一个乘法,一个积分即可。

注意由于 [x0]A(x)=1[x^0]A(x)=1,所以有 [x0]B(x)=0[x^0]B(x)=0。并且若 [x0]A(x)=1[x^0]A(x)\not=1 则无法求 ln\ln 因为求不出模意义下的 ln([x0]A(x))\ln([x^0]A(x))

时间复杂度 O(nlogn)O(n\log n)

多项式 exp\exp

P4726 【模板】多项式指数函数(多项式 exp)

B(x)exp(A(x))(modxn)B(x)\equiv \exp(A(x))\pmod{x^n}

我们设 G(F(x))=ln(F(x))A(x)G(F(x))=\ln(F(x))-A(x),那么显然 G(B(x))0(modxn)G(B(x))\equiv0\pmod{x^n},可以使用牛顿迭代了。

回忆牛迭式子:F(x)F(x)G(F(x))G(F(x))(modxn)F(x)\equiv F_*(x)-\dfrac{G(F_*(x))}{G'(F_*(x))}\pmod{x^n}

显然,这里的 G(F(x))=1F(X)G'(F(x))=\dfrac{1}{F(X)},那么假设我们已经求出了 B(x)exp(A(x))(modxn2)B_*(x)\equiv \exp(A(x))\pmod{x^{\frac{n}{2}}},有:

B(x)B(x)G(B(x))B(x)(modxn)B(x)B(x)(ln(B(x))A(x))B(x)(modxn)B(x)(1ln(B(x))+A(x))B(x)(modxn)\begin{aligned} B(x)&\equiv B_*(x)-G(B_*(x))B_*(x)&\pmod{x^n}\\ B(x)&\equiv B_*(x)-(\ln(B_*(x))-A(x))B_*(x)&\pmod{x^n}\\ B(x)&\equiv (1-\ln(B_*(x))+A(x))B_*(x)&\pmod{x^n}\\ \end{aligned}

所以倍增求即可。

注意由于 [x0]A(x)=0[x^0]A(x)=0,所以有 [x0]B(x)=1[x^0]B(x)=1。并且若 [x0]A(x)=0[x^0]A(x)\not=0 则无法求 exp\exp 因为求不出模意义下的 exp([x0]A(x))\exp([x^0]A(x))

时间复杂度为 T(n)=T(n2)+O(nlogn)=O(nlogn)T(n)=T(\frac{n}{2})+O(n\log n)=O(n\log n)

多项式快速幂

P5245 【模板】多项式快速幂

观察到 (A(x))k=exp(kln(A(x)))(A(x))^k=\exp(k\ln(A(x))),所以一个 ln\ln,一个逐项乘法,一个 exp\exp 就做完了。

P5273 【模板】多项式幂函数(加强版)

这题和上一题的区别在于有 A0=1A_0\not=1 的情况,这时我们就没办法求 ln\lnexp\exp 了。

但是 A0=1A_0\not=1 没关系,我们可以让所有项都乘上 1A0\dfrac{1}{A_0},最后再都乘上 A0kA_0^k 即可。

遇到 A0=0A_0=0 的情况也没关系,把系数往前移,求出答案后再移回去即可。不过要注意原来前面 cntcnt00 在做幂运算后会变成 cnt×kcnt\times k00

多项式带余除法

P4512 【模板】多项式除法

发现余数很烦,所以我们想办法去掉它。

舍弃多项式的项的方法是一般是加上次数界,但注意到次数界只能舍弃高次,所以考虑把多项式的系数反过来搞

那么回到题目的式子:

F(x)=Q(x)G(x)+R(x)F(x)=Q(x)G(x)+R(x)

其中 FFnn 次多项式(已知),GGmm 次多项式(已知),QQnmn-m 次多项式(未知),RRm1m-1 次多项式(未知)。

换元,有:

F(1x)=Q(1x)G(1x)+R(1x)F(\frac{1}{x})=Q(\frac{1}{x})G(\frac{1}{x})+R(\frac{1}{x})

同乘 xnx^n,有:

xnF(1x)=xnQ(1x)G(1x)+xnR(1x)x^nF(\frac{1}{x})=x^nQ(\frac{1}{x})G(\frac{1}{x})+x^nR(\frac{1}{x})

发现 xnF(1x)=FR(x)x^nF(\frac{1}{x})=F_R(x)xnQ(1x)G(1x)=QR(x)GR(x)x^nQ(\frac{1}{x})G(\frac{1}{x})=Q_R(x)G_R(x)xnR(1x)=xnm+1RR(x)x^nR(\frac{1}{x})=x^{n-m+1}R_R(x),所以有:

FR(x)=QR(x)GR(x)+xnm+1RR(x)F_R(x)=Q_R(x)G_R(x)+x^{n-m+1}R_R(x)

那么我们机智地加上次数界,mod\operatorname{mod}xnm+1x^{n-m+1},就有:

FR(x)QR(x)GR(x)(modxnm+1)F_R(x)\equiv Q_R(x)G_R(x)\pmod{x^{n-m+1}}

那么就可以求出 QR(x)Q_R(x) 了,系数反过来就是 Q(x)Q(x),然后即可用乘法和减法求出 R(x)R(x),时间复杂度 O(nlogn)O(n\log n)

完整模板

包括多项式多点求值、多项式快速插值。

展开

const int p=998244353,ginv=332748118;

inline void add(int &x,int y)
{
	x+=y;
	if(x>=p) x-=p;
}
inline int gcd(int a,int b)
{
	int t=a%b;
	while(t!=0) a=b,b=t,t=a%b;
	return b;
}
inline int qpow(int x,int y)
{
	int res=1;
	for(;y>0;y>>=1) res=((y&1)?1ll*res*x%p:res),x=1ll*x*x%p;
	return res;
}
inline int exBSGS(int a,int b,int p)
{
	a%=p,b%=p;
	if(b==1||p==1) return 0;
	int cnt=0,val=1;
	while(1)
	{
		int d=gcd(a,p);
		if(d==1) break;
		if(b%d!=0) return -1;
		p/=d;
		b/=d;
		val=1ll*val*(a/d)%p;
		cnt++;
		if(val==b) return cnt;
	}
	map<int,int> mp;
	int val2=1,t=sqrt(p)+1;
	for(int B=1;B<=t;B++)
	{
		val2=1ll*val2*a%p;
		mp[1ll*b*val2%p]=B;
	}
	int cur=val;
	for(int A=1;A<=t;A++)
	{
		cur=1ll*cur*val2%p;
		if(mp.find(cur)!=mp.end()) return A*t-mp[cur]+cnt;
	}
	return -1;
}
inline int mosqrt(int x)
{
	int bse=exBSGS(3,x,p);
	if(bse==-1||(bse&1)) return -1;
	return qpow(3,bse/2);
}
namespace PLOY
{
	const int MS=5000005;

	typedef vector<int> ploy;
	
	inline ploy operator+(ploy a,ploy b);
	inline ploy operator+(ploy a,int b);
	inline ploy operator+(int a,ploy b);
	
	inline ploy operator-(ploy a,ploy b);
	inline ploy operator-(ploy a,int b);
	inline ploy operator-(int a,ploy b);
	
	inline ploy operator*(int b,ploy a);
	inline ploy operator*(ploy a,int b);
	inline ploy operator*(int b,ploy a);
	
	inline ploy inv(ploy a);
	inline ploy dao(ploy a);
	inline ploy jif(ploy a);
	inline ploy ln(ploy a);
	inline ploy exp(ploy a);
	inline ploy pow(ploy a,int b);
	inline ploy pow2(ploy a,int b,int b2); // b%p b2%(p-1)
	
	inline void divi(ploy a,ploy b,ploy &res,ploy &r);
	inline ploy operator%(ploy a,ploy b);
	
	inline vector<int> getval(ploy a,vector<int> x);
	inline ploy getploy(vector<int> x,vector<int> y);
	
	int p_rev[MS],p_rev_lstn;
	int p_tmpinv[MS];
	inline int getlen(int n)
	{
		int res=1;
		while(res<n) res<<=1;
		return res;
	}
	inline void NTT(ploy &a,int tpe)
	{
		int n=a.size();
		if(p_rev_lstn!=n)
		{
			p_rev_lstn=n;
			for(int i=0;i<n;i++) p_rev[i]=(p_rev[i>>1]>>1)|((i&1)?n>>1:0);
		}
		for(int i=0;i<n;i++) if(p_rev[i]<i) swap(a[p_rev[i]],a[i]);
		int g=tpe==1?3:ginv;
		for(int mid=1;mid<n;mid<<=1)
		{
			int len=mid<<1,Wn=qpow(g,(p-1)/len);
			for(int l=0;l<n-len+1;l+=len)
			{
				for(int k=0,Wk=1;k<mid;k++,Wk=1ll*Wk*Wn%p)
				{
					int x=a[l+k],y=1ll*Wk*a[l+mid+k]%p;
					a[l+k]=(x+y)%p,a[l+mid+k]=(x-y+p)%p;
				}
			}
		}
	}
	inline void DFT(ploy &a){NTT(a,1);}
	inline void IDFT(ploy &a)
	{
		int n=a.size();
		NTT(a,-1);
		int inv=qpow(n,p-2);
		for(int i=0;i<n;i++) a[i]=1ll*a[i]*inv%p;
	}
	inline ploy operator+(ploy a,ploy b)
	{
		if(a.size()<b.size()) swap(a,b);
		for(int i=0;i<b.size();i++) add(a[i],b[i]);
		return a;
	}
	inline ploy operator+(ploy a,int b){return add(a[0],b),a;}
	inline ploy operator+(int a,ploy b){return add(b[0],a),b;}
	inline ploy operator-(ploy a,ploy b)
	{
		if(a.size()<b.size()) a.resize(b.size(),0);
		for(int i=0;i<b.size();i++) add(a[i],p-b[i]);
		return a;
	}
	inline ploy operator-(ploy a,int b){return add(a[0],p-b),a;}
	inline ploy operator-(int a,ploy b)
	{
		add(b[0],p-a);
		for(int i=0;i<b.size();i++) b[i]=p-b[i];
		return b;
	}
	inline ploy operator*(ploy a,int b)
	{
		for(int i=0;i<a.size();i++) a[i]=1ll*a[i]*b%p;
		return a;
	}
	inline ploy operator*(int b,ploy a)
	{
		for(int i=0;i<a.size();i++) a[i]=1ll*a[i]*b%p;
		return a;
	}
	inline ploy operator*(ploy a,ploy b)
	{
		int n=a.size()+b.size()-1,m=getlen(n);
		a.resize(m,0),b.resize(m,0);
		DFT(a),DFT(b);
		for(int i=0;i<m;i++) a[i]=1ll*a[i]*b[i]%p;
		IDFT(a);
		a.resize(n,0);
		return a;
	}
	inline ploy inv(ploy a)
	{
		int n=a.size(),m=getlen(n);
	    ploy res={qpow(a[0],p-2)};
	    for(int len=2;len<=m;len<<=1)
	    {
	    	ploy tmp=a;
	    	tmp.resize(len,0);
	    	res=res*2-res*res*tmp;
	    	res.resize(len,0);
	    }
	    res.resize(n,0);
	    return res;
	}
	inline ploy sqrt(ploy a)
	{
		ploy res={mosqrt(a[0])};
	    if(res[0]==-1) return ploy();
	    int n=a.size(),m=getlen(n)*2; // 不知道为什么要乘二
	    for(int len=2;len<=m;len<<=1)
	    {
	    	ploy tmp=a;
	    	tmp.resize(len,0);
	    	res=(res*res+tmp)*inv(res*2);
	    	res.resize(len,0);
	    }
	    res.resize(n,0);
	    return res;
	}
	inline ploy dao(ploy a)
	{
		int n=a.size();ploy res=a;
		res[n-1]=0;for(int i=1;i<n;i++) res[i-1]=1ll*a[i]*i%p;
		return res;
	}
	inline ploy jif(ploy a)
	{
		int n=a.size();ploy res=a;
		for(int i=1;i<n;i++) if(p_tmpinv[i]==0) p_tmpinv[i]=(i==1?1:1ll*p_tmpinv[p%i]*(p-p/i)%p);
		res[0]=0;for(int i=1;i<n;i++) res[i]=1ll*a[i-1]*p_tmpinv[i]%p;
		return res;
	}
	inline ploy ln(ploy a)
	{
		int n=a.size();
		ploy res=jif(dao(a)*inv(a));
		res.resize(n,0);
		return res;
	}
	inline ploy exp(ploy a)
	{
		int n=a.size(),m=getlen(n)*2;
		ploy res={1};
		for(int len=2;len<=m;len<<=1)
		{
			ploy tmp=a;
			tmp.resize(len,0);
			res=(1-ln(res)+tmp)*res;
			res.resize(len,0);
		}
		res.resize(n,0);
		return res;
	}
	inline ploy pow(ploy a,int b)
	{
		ploy tmp=ln(a);
		for(int i=0;i<a.size();i++) tmp[i]=1ll*tmp[i]*b%p;
		return exp(tmp);
	}
	inline ploy pow2(ploy a,int b,int b2)
	{
		int n=a.size(),cnt=0;
		for(int i=0;i<n&&a[i]==0;i++) cnt++;
		if(1ll*cnt*b2>=n) return ploy(n,0);
		int pos=cnt*b2,m=n-pos;
		ploy tmp;
		for(int i=cnt;i<cnt+m;i++) tmp.push_back(a[i]);
		int inv=qpow(tmp[0],p-2),ml=qpow(tmp[0],b2);
		for(int i=0;i<m;i++) tmp[i]=1ll*tmp[i]*inv%p;
		tmp=pow(tmp,b);
		for(int i=0;i<m;i++) tmp[i]=1ll*tmp[i]*ml%p;
		ploy res(pos,0);
		for(int i=0;i<m;i++) res.push_back(tmp[i]);
		return res;
	}
	inline void divi(ploy a,ploy b,ploy &res,ploy &r)
	{
		int n=a.size(),m=b.size();
		if(n<m) return res={0},r=a,void();
		int rl=n-m+1;
		ploy ta=a,tb=b;
		reverse(ta.begin(),ta.end()),reverse(tb.begin(),tb.end());
		ta.resize(rl,0),tb.resize(rl,0);
		res=ta*inv(tb);
		res.resize(rl,0);
		reverse(res.begin(),res.end());
		r=a-b*res;
		r.resize(m-1,0);
	}
	inline ploy operator%(ploy a,ploy b)
	{
		ploy res,r;
		divi(a,b,res,r);
		return r;
	}
	inline vector<int> getval(ploy a,vector<int> x)
	{
		int n=x.size();
		vector<ploy> ml(n<<2|1),res(n<<2|1);
		vector<pair<pair<int,int>,pair<int,int> > > sta;
		vector<int> ans(n);
		sta.emplace_back(make_pair(1,0),make_pair(0,n-1));
		while(!sta.empty())
		{
			auto t=*sta.rbegin();
			sta.pop_back();
			int u=t.first.first,stp=t.first.second,l=t.second.first,r=t.second.second;
			if(l==r)
			{
				ml[u]={p-x[l],1};
				continue;
			}
			int mid=l+r>>1;
			if(stp==0)
			{
				t.first.second++;
				sta.push_back(t);
				sta.emplace_back(make_pair(u<<1,0),make_pair(l,mid));
			}
			else if(stp==1)
			{
				t.first.second++;
				sta.push_back(t);
				sta.emplace_back(make_pair(u<<1|1,0),make_pair(mid+1,r));
			}
			else ml[u]=ml[u<<1]*ml[u<<1|1];
		}
		res[1]=a%ml[1];
		sta.emplace_back(make_pair(1,0),make_pair(0,n-1));
		while(!sta.empty())
		{
			auto t=*sta.rbegin();
			sta.pop_back();
			int u=t.first.first,l=t.second.first,r=t.second.second;
			if(l==r)
			{
				ans[l]=res[u][0];
				continue;
			}
			int mid=l+r>>1;
			res[u<<1]=res[u]%ml[u<<1];
			res[u<<1|1]=res[u]%ml[u<<1|1];
			sta.emplace_back(make_pair(u<<1,0),make_pair(l,mid));
			sta.emplace_back(make_pair(u<<1|1,0),make_pair(mid+1,r));
		}
		return ans;
	}
	inline ploy getploy(vector<int> x,vector<int> y)
	{
		int n=x.size();
		vector<ploy> ml(n<<2|1),res(n<<2|1);
		vector<pair<pair<int,int>,pair<int,int> > > sta;
		vector<int> ans(n);
		sta.emplace_back(make_pair(1,0),make_pair(0,n-1));
		while(!sta.empty())
		{
			auto t=*sta.rbegin();
			sta.pop_back();
			int u=t.first.first,stp=t.first.second,l=t.second.first,r=t.second.second;
			if(l==r)
			{
				ml[u]={p-x[l],1};
				continue;
			}
			int mid=l+r>>1;
			if(stp==0)
			{
				t.first.second++;
				sta.push_back(t);
				sta.emplace_back(make_pair(u<<1,0),make_pair(l,mid));
			}
			else if(stp==1)
			{
				t.first.second++;
				sta.push_back(t);
				sta.emplace_back(make_pair(u<<1|1,0),make_pair(mid+1,r));
			}
			else ml[u]=ml[u<<1]*ml[u<<1|1];
		}
		ploy M=dao(ml[1]);
		vector<int> val=getval(M,x);
		for(int i=0;i<n;i++) val[i]=1ll*qpow(val[i],p-2)*y[i]%p;
		sta.emplace_back(make_pair(1,0),make_pair(0,n-1));
		while(!sta.empty())
		{
			auto t=*sta.rbegin();
			sta.pop_back();
			int u=t.first.first,stp=t.first.second,l=t.second.first,r=t.second.second;
			if(l==r)
			{
				res[u]={val[l]};
				continue;
			}
			int mid=l+r>>1;
			if(stp==0)
			{
				t.first.second++;
				sta.push_back(t);
				sta.emplace_back(make_pair(u<<1,0),make_pair(l,mid));
			}
			else if(stp==1)
			{
				t.first.second++;
				sta.push_back(t);
				sta.emplace_back(make_pair(u<<1|1,0),make_pair(mid+1,r));
			}
			else res[u]=res[u<<1]*ml[u<<1|1]+ml[u<<1]*res[u<<1|1];
		}
		return res[1];
	}
}

更快的板子

展开

#include <bits/stdc++.h>

using ull = unsigned long long;

const int N = 280000;
const int Mod = 998244353;

typedef std::vector<int> Poly;

namespace Pol {
	int pow(int a, int b, int ans = 1);
	int add(int a, int b) {
		return (a += b) >= Mod ? a -= Mod : a;
	}
	int sub(int a, int b) {
		return (a -= b) < 0 ? a += Mod : a;
	}
	void inc(int &a, int b) {
		(a += b) >= Mod ? a -= Mod : a;
	}
	void dec(int &a, int b) {
		(a -= b) < 0 ? a += Mod : a;
	}
	void init_Poly(int n = N);
	void DIT(int *A, int lim);
	void DIF(int *A, int lim);
	Poly inv(Poly A, int n);
	Poly mult(const Poly &A, int n, const Poly &B, int m);
	Poly operator*(const Poly &A, const Poly &B) {
		return mult(A, A.size(), B, B.size());
	}
	Poly Tmul(const Poly &A, int n, const Poly &B, int m);
	Poly getv(Poly A, int n, const std::vector<int> &f, int m);
	Poly drv(const Poly &A, int n);
	Poly itg(const Poly &A, int n);
	Poly ln(const Poly &A, int n);
	Poly exp(Poly A, int n);
	int fac[N], ifac[N], iv[N];
	Poly G[N << 1];
	ull tmp[N];
	int gw[N];
}  // namespace Pol

int main() {
	Pol::init_Poly();
	int n, m;
	scanf("%d %d", &n, &m);
	Poly F(n);
	for (int i = 0; i < n; ++i) scanf("%d", &F[i]);
	Poly G = Pol::ln(Pol::inv(Pol::exp(F, n), n), n);
	for (int i = 0; i < n; ++i) printf("%d%c", G[i], " \n"[i == n - 1]);
	std::vector<int> f(m);
	for (int i = 0; i < m; ++i) scanf("%d", &f[i]);
	G = Pol::getv(F, n, f, m);
	for (int i = 0; i < m; ++i) printf("%d%c", G[i], " \n"[i == m - 1]);
	return 0;
}

namespace Pol {
	void DIT(int *A, int lim) {
		for (int i = 0; i < lim; ++i) tmp[i] = A[i];
		for (int l = 1; l < lim; l <<= 1) {
			ull *k = tmp;
			for (int *g = gw; k < tmp + lim; k += (l << 1), ++g) {
				for (ull *x = k; x < k + l; ++x) {
					int o = x[l] % Mod;
					x[l] = 1ll * (*x + Mod - o) **g % Mod, *x += o;
				}
			}
		}
		int iv = pow(lim, Mod - 2);
		for (int i = 0; i < lim; ++i) A[i] = 1ll * tmp[i] % Mod * iv % Mod;
		std::reverse(A + 1, A + lim);
	}
	void DIF(int *A, int lim) {
		for (int i = 0; i < lim; ++i) tmp[i] = A[i];
		for (int l = lim / 2; l >= 1; l >>= 1) {
			ull *k = tmp;
			for (int *g = gw; k < tmp + lim; k += (l << 1), ++g) {
				for (ull *x = k; x < k + l; ++x) {
					int o = 1ll * x[l] **g % Mod;
					x[l] = *x + Mod - o, *x += o;
				}
			}
		}
		for (int i = 0; i < lim; ++i) A[i] = tmp[i] % Mod;
	}
	Poly mult(const Poly &A, int n, const Poly &B, int m) {
		if (n + m < 255) {
			Poly ans(n + m - 1);
			std::fill(tmp, tmp + n + m, 0);
			for (int i = 0; i < n; ++i)
				for (int j = 0; j < m; ++j) tmp[i + j] += 1ll * A[i] * B[j] % Mod;
			for (int i = 0; i < n + m - 1; ++i) ans[i] = tmp[i] % Mod;
			return ans;
		}
		int lim = 1;
		while (lim < (n + m - 1)) lim <<= 1;
		static int tA[N], tB[N];
		std::copy_n(A.begin(), n, tA), std::fill(tA + n, tA + lim, 0);
		std::copy_n(B.begin(), m, tB), std::fill(tB + m, tB + lim, 0);
		DIF(tA, lim), DIF(tB, lim);
		for (int i = 0; i < lim; ++i) tA[i] = 1ll * tA[i] * tB[i] % Mod;
		DIT(tA, lim);
		Poly ans(n + m - 1);
		std::copy_n(tA, n + m - 1, ans.begin());
		return ans;
	}
	Poly Tmul(const Poly &A, int n, const Poly &B, int m) {
		if (n + m < 255) {
			Poly ans(m - n + 1);
			std::fill(tmp, tmp + m - n + 2, 0);
			for (int i = 0; i < m; ++i)
				for (int j = i; j < n; ++j) tmp[j - i] += 1ll * B[i] * A[j] % Mod;
			for (int i = 0; i < m - n + 1; ++i) ans[i] = tmp[i] % Mod;
		}
		int lim = 1;
		while (lim < m) lim <<= 1;
		static int tA[N], tB[N];
		std::reverse_copy(A.begin(), A.begin() + n, tA), std::fill(tA + n, tA + lim, 0);
		std::copy_n(B.begin(), m, tB), std::fill(tB + m, tB + lim, 0);
		DIF(tA, lim), DIF(tB, lim);
		for (int i = 0; i < lim; ++i) tA[i] = 1ll * tA[i] * tB[i] % Mod;
		DIT(tA, lim);
		Poly ans(m - n + 1);
		std::copy_n(tA + n - 1, m - n + 1, ans.begin());
		return ans;
	}
	Poly inv(Poly A, int n) {
		int lim = 1;
		while (lim < (n << 1)) lim <<= 1;
		Poly F(lim), G(lim);
		A.resize(lim);
		G[0] = pow(A[0], Mod - 2);
		int now = 1;
		static int tA[N], tB[N];
		while (now < n) {
			std::copy_n(A.begin(), now << 1, F.begin());
			int lim = now << 2;
			std::copy_n(G.begin(), lim, tA);
			std::copy_n(F.begin(), lim, tB);
			DIF(tA, lim), DIF(tB, lim);
			for (int i = 0; i < lim; ++i) tA[i] = 1ll * sub(2, 1ll * tA[i] * tB[i] % Mod) * tA[i] % Mod;
			DIT(tA, lim);
			std::copy_n(tA, now << 1, G.begin());
			now <<= 1;
		}
		G.resize(n);
		return G;
	}
	Poly drv(const Poly &A, int n) {
		Poly ans(n - 1);
		for (int i = 0; i < n - 1; ++i) ans[i] = 1ll * A[i + 1] * (i + 1) % Mod;
		return ans;
	}
	Poly itg(const Poly &A, int n) {
		Poly ans(n + 1);
		for (int i = 0; i < n; ++i) ans[i + 1] = 1ll * A[i] * iv[i + 1] % Mod;
		return ans;
	}
	Poly ln(const Poly &A, int n) {
		Poly F = drv(A, n), G = inv(A, n);
		F = mult(F, n - 1, G, n);
		F.resize(n - 1);
		F = itg(F, n - 1);
		return F;
	}
	Poly exp(Poly A, int n) {
		int lim = 1;
		while (lim < (n << 1)) lim <<= 1;
		A.resize(lim);
		Poly L(lim);
		int now = 1;
		static int tF[N], tG[N], tL[N];
		std::fill(tG, tG + lim, 0), std::fill(tF, tF + lim, 0);
		tG[0] = 1;
		while (now < n) {
			int lim = now << 2;
			std::copy_n(tG, now, L.begin());
			L = ln(L, std::min(now << 1, n));
			L.resize(lim);
			std::copy_n(A.begin(), now << 1, tF);
			std::copy_n(L.begin(), lim, tL);
			DIF(tF, lim), DIF(tG, lim), DIF(tL, lim);
			for (int i = 0; i < lim; ++i) tG[i] = 1ll * tG[i] * sub(add(1, tF[i]), tL[i]) % Mod;
			DIT(tG, lim);
			std::fill(tG + (now << 1), tG + lim, 0);
			now <<= 1;
		}
		Poly G(n);
		std::copy_n(tG, n, G.begin());
		return G;
	}
	void getg(int x, int xl, int xr, const Poly &f, int m) {
		if (xl == xr) {
			G[x].resize(2);
			G[x][0] = 1;
			if (xl >= m)
				G[x][1] = 0;
			else
				G[x][1] = Mod - f[xl];
			return;
		}
		int xm = (xl + xr) >> 1;
		getg(x << 1, xl, xm, f, m), getg(x << 1 | 1, xm + 1, xr, f, m);
		G[x] = mult(G[x << 1], xm - xl + 2, G[x << 1 | 1], xr - xm + 1);
	}
	void getans(int x, int xl, int xr, Poly &ans, int m, const Poly &h) {
		if (xl >= m) return;
		if (xl == xr) return void(ans[xl] = h[0]);
		int xm = (xl + xr) >> 1;
		Poly hl = Tmul(G[x << 1 | 1], xr - xm + 1, h, xr - xl + 1);
		getans(x << 1, xl, xm, ans, m, hl);
		Poly hr = Tmul(G[x << 1], xm - xl + 2, h, xr - xl + 1);
		getans(x << 1 | 1, xm + 1, xr, ans, m, hr);
	}
	Poly getv(Poly A, int n, const std::vector<int> &f, int m) {
		n = std::max(n, m);
		A.resize(n);
		getg(1, 0, n - 1, f, m);
		Poly now = inv(G[1], n);
		std::reverse(now.begin(), now.begin() + n);
		Poly h = mult(now, n, A, n);
		for (int i = 0; i < n; ++i) h[i] = h[i + n - 1];
		h.resize(n);
		Poly ans(m);
		getans(1, 0, n - 1, ans, m, h);
		return ans;
	}
	void init_Poly(int n) {
		int t = 1;
		while ((1 << t) < n) ++t;
		t = std::min(t - 1, 21);
		gw[0] = 1, gw[1 << t] = pow(31, 1 << (21 - t));
		for (int i = t; i; --i) gw[1 << (i - 1)] = 1ll * gw[1 << i] * gw[1 << i] % Mod;
		for (int i = 1; i < (1 << t); ++i) gw[i] = 1ll * gw[i & (i - 1)] * gw[i & -i] % Mod;
		--n;
		fac[0] = 1;
		for (int i = 1; i <= n; ++i) fac[i] = 1ll * fac[i - 1] * i % Mod;
		ifac[n] = Pol::pow(fac[n], Mod - 2);
		for (int i = n - 1; i >= 0; --i) ifac[i] = 1ll * ifac[i + 1] * (i + 1) % Mod;
		for (int i = 1; i <= n; ++i) iv[i] = 1ll * ifac[i] * fac[i - 1] % Mod;
	}
	int pow(int a, int b, int ans) {
		while (b) {
			if (b & 1) ans = 1ll * ans * a % Mod;
			a = 1ll * a * a % Mod;
			b >>= 1;
		}
		return ans;
	}
}  // namespace Pol