ABC236Ex Distinct Multiples 做题记录

给定 n,mn,m 和一个长 nn 的序列 aia_i,求满足以下条件的长 nn 的序列 bb 的个数

  • 1bim1\le b_i\le m
  • bib_i 两两不同;
  • bib_i 可以被 aia_i 整除;

998244353998244353 取模。

1n161\le n\le 161aim10181\le a_i\le m\le 10^{18}

bi=bjb_i=b_j 则连接无向边 (i,j)(i,j),则两两不同的限制相当于图中没有边。

考虑容斥,设 dpSdp_SiSi\in Sbib_i 的方案数。钦定某些点在同一个连通块,不难发现相同大小的连通块的容斥系数是相同的。不妨设大小为 nn 的连通块的容斥系数为 fnf_n,则有转移:

dpS=TS,min{T}=min{S}fTmlcm{aiiT}dpSTdp_S=\sum\limits_{T\subseteq S,\min\{T\}=\min\{S\}}f_{|T|}\left\lfloor\frac{m}{\text{lcm}\{a_i|i\in T\}}\right\rfloor dp_{S-T}

考虑 fnf_n 需要满足的条件,设 gng_n 表示所有 nn 个点的无向图的容斥系数的总和,那么枚举图中边的个数,有:

gn=i=0n(n1)2(n(n1)2i)(1)i=i=0n(n1)2(n(n1)2i)(1)i1ni=[n=0]+[n=1]\begin{aligned} g_n&=\sum\limits_{i=0}^{\frac{n(n-1)}{2}}\binom{\frac{n(n-1)}{2}}{i}(-1)^i\\ &=\sum\limits_{i=0}^{\frac{n(n-1)}{2}}\binom{\frac{n(n-1)}{2}}{i}(-1)^i1^{n-i}\\ &=[n=0]+[n=1] \end{aligned}

而枚举 nn 所在的连通块大小,有:

gn=i=1n(n1i1)figni=fn+(n1)fn1\begin{aligned} g_n&=\sum\limits_{i=1}^{n}\binom{n-1}{i-1}f_{i}g_{n-i}\\ &=f_n+(n-1)f_{n-1} \end{aligned}

g1=f1=1g_1=f_1=1

所以有:

fn={1n=1(n1)fn1n>1f_{n}=\begin{cases} 1&n=1\\ -(n-1)f_{n-1}&n>1 \end{cases}

那么 fn=(1)n1(n1)!f_n=(-1)^{n-1}(n-1)!

那么直接 dp 即可,时间复杂度 O(3n)O(3^n),代码如下:

#include <iostream>
#include <cstdio>

using namespace std;

typedef long long ll;

const int S=20,BS=1<<16,p=998244353;

int n;
ll m,a[S];
ll lcm[BS];
int f[S],dp[BS];

#define popc __builtin_popcount

inline ll gcd(ll x,ll y)
{
	if(x==0||y==0) return x+y;
	ll t=x%y;
	while(t!=0) x=y,y=t,t=x%y;
	return y;
}

inline ll getlcm(int st)
{
	ll res=1;
	for(int i=1;i<=n;i++)
	{
		if(st>>i-1&1)
		{
			ll g=gcd(res,a[i]);
			res/=g;
			if(res>m/a[i]) res=m+1;
			else res*=a[i];
		}
	}
	return res;
}

int main()
{
	scanf("%d%lld",&n,&m);
	for(int i=1;i<=n;i++) scanf("%lld",&a[i]);
	for(int i=0;i<(1<<n);i++) lcm[i]=getlcm(i);
	f[1]=1;
	for(int i=2;i<=n;i++) f[i]=p-1ll*(i-1)*f[i-1]%p;
	dp[0]=1;
	for(int i=1;i<(1<<n);i++)
	{
		for(int j=i;j>0;j=(j-1)&i)
		{
			if((i&-i)!=(j&-j)) continue;
			dp[i]=(dp[i]+1ll*f[popc(j)]*((m/lcm[j])%p)%p*dp[i^j]%p)%p;
		}
	}
	printf("%d\n",dp[(1<<n)-1]);
	return 0;
}