ARC143F Counting Subsets 做题记录

给定正整数 nn,求有多少个 {1,2,3,,n}\{1,2,3,\dots,n\} 的子集 SS,满足对于任意 [1,n][1,n] 中的正整数 xx,都存在一个或两个 SS 的子集 TT 满足 x=yTyx=\sum\limits_{y\in T}y。对 998244353998244353 取模。

1n15001\le n\le 1500

相当于用 SS 中的数做 01 背包后背包数组中 [1,n][1,n] 每个位置的值都是 11 或者 22

考虑从小到大加入 SS 中的数,每次相当于将背包数组整体往后移位再对位相加。那么如果某一时刻出现了 >1>1 个连续非 00 段,则它们之间的 00 就再也无法被覆盖到了。故每个时刻背包数组中非 00 的位置一定是一个前缀,且该前缀长度是 SS 中所有数的和再 +1+1bb 的下标从 00 开始)。

那么设背包数组为 bbSS 中数的和为 smsm,则一定有 bx=bsmxb_x=b_{sm-x}

考虑第一次让 bb 中出现 22 的数 aa,此前加入的数一定都是 22 的次幂。设 lenlen 为最小的 2x2^x 满足 len>alen>a,则此时 bb 一定形如:

111a 个 1222lena 个 2111a 个 1\underset{\text{$a$ 个 $1$}}{\underbrace{11\dots1}}\underset{\text{$len-a$ 个 $2$}}{\underbrace{22\dots2}}\underset{\text{$a$ 个 $1$}}{\underbrace{11\dots1}}

接下来每次操作新产生的 22 都来自于最后一段 11 和第一段 11 的叠合,直到 b[0,n]b_{[0,n]} 都非 00,此时显然只能再操作最多一次。问题是这一次操作中第一段 11 叠合的可能不再是最后一段 11

观察到 bb 的性质很好(是回文串且砍掉最后一次操作叠合出的中间那段后,两边也是回文串),考虑建树,去掉开头的 aa11 后,每次操作叠合出的中间段(长度范围 [a,2a][a,2a])作为节点,向左右两边最后一次操作对应的节点连边。叶子节点对应 lenalen-a22(第一次出现 22 的操作)。

这样我们就可以枚举是在哪个节点处爆 [0,n][0,n],按层做背包即可统计答案。

复杂度:

  • 枚举 aa 带来一个 O(n)O(n)
  • 一共有 O(na)O(\frac{n}{a}) 个节点;
  • 每个节点处需要做背包,复杂度是 O(nlnn)O(n\ln n) 的(ln\ln 来自于调和级数);
  • 所以总复杂度为 aO(na×nlnn)\sum\limits_{a}O(\frac{n}{a}\times n\ln n),即 O(n2ln2n)O(n^2\ln^2n)

实现起来有很多细节,要考虑好每种情况的贡献应该算在什么节点上。

代码如下:

#include <iostream>
#include <cstdio>
#include <cstring>

using namespace std;

typedef long long ll;
const int S=1505,BS=10;

#define p 998244353

int n;
ll siz[BS+5];
int g[S],tmp[BS+5][S];
int ans;

inline void add(int &x,int y)
{
	x+=y;
	if(x>=p) x-=p;
}

void dfs(int a,int len,int hei,int cnt,ll sm)
{
	if(sm>n+1) return;
	if(hei>1)
	{
		{// left
			memcpy(tmp[hei],g,sizeof(g));
			if(cnt>1)
			{
				int ml=cnt-1;
				for(int i=n+1;i>=0;i--)
				{
					for(int j=a;j<=a*2&&i+j*ml<=n+1;j++)
						add(g[i+j*ml],g[i]);
					g[i]=0;
				}
			}
			dfs(a,len,hei-1,(cnt-1)*2+1,sm);
			memcpy(g,tmp[hei],sizeof(g));
		}
		if(sm+a+siz[hei-1]<=n+1)
		{// right
			memcpy(tmp[hei],g,sizeof(g));
			int ml=cnt;
			for(int i=n+1;i>=0;i--)
			{
				for(int j=a;j<=a*2&&i+j*ml<=n+1;j++)
					add(g[i+j*ml],g[i]);
				g[i]=0;
			}
			dfs(a,len,hei-1,cnt*2,sm+a+siz[hei-1]);
			memcpy(g,tmp[hei],sizeof(g));
		}
	}
	if(a+sm+siz[hei-1]>n+1) return;
	memcpy(tmp[hei],g,sizeof(g));
	ll ml=(cnt-1)*2+1;
	for(int i=hei-1;i>=1;i--,ml*=2)
	{
		int lb,rb;
		if(i>1) lb=a,rb=a*2;
		else lb=rb=len-a;
		for(int j=n+1;j>=0;j--)
		{
			for(int k=lb;k<=rb&&j+k*ml<=n+1;k++)
				add(g[j+k*ml],g[j]);
			g[j]=0;
		}
	}
	if(hei>1) // not leaf
	{
		for(int i=0;i<n+1-a;i++)
			for(int x=a;x<=a*2&&a+i+x*(cnt-1)<n+1;x++)
			{
				int pos=a+i+x*(cnt-1);
				int c1=x-a;
				if(pos+x<n+1) continue;
				if(cnt==1&&pos+c1>n+1) break;
				add(ans,g[i]);
				int r1=0;
				if(cnt==1)
				{
					if(pos+x-c1<n+1)
						r1=n+1-(pos+x-c1)-(x==a*2);
				}
				else
				{
					if(pos+c1>=n+1) r1=n+1-pos;
					else if(x==a*2) r1=min(n+1-pos,a);
					else if(pos+x-c1<n+1) r1=n+1-(pos+x-c1);
				}
				add(ans,1ll*r1*g[i]%p);
			}
	}
	else
	{
		int ad=a+(len-a)*(cnt-1);
		for(int i=0;i<n+1-ad;i++)
		{
			int pre=i+ad;
			int x=len-a;
			if(pre+x<n+1) continue;
			add(ans,g[i]);
		}
	}
	memcpy(g,tmp[hei],sizeof(g));
}

int main()
{
	ios::sync_with_stdio(false),cin.tie(0),cout.tie(0);
	cin>>n;
	ans=1;
	for(int a=1;a<=n+1;a++)
	{
		int len=1;
		while(len<=a) len<<=1;
		if(len==a*2) continue;
		siz[1]=len-a;
		for(int i=2;i<=BS;i++) siz[i]=siz[i-1]*2+a;
		g[0]=1;
		int lst=ans;
		dfs(a,len,BS,1,0);
		// for(int i=1;i<=BS-3;i++)
		// {
			// int lst=ans;
			// dfs(a,len,i,1,0);
			// if(ans>lst) printf("%d: %d\n",i,ans-lst);
		// }
		// printf(">> %d %d %d : %d\n",a,len-a,a,ans-lst);
	}
	cout<<ans<<'\n';
	return 0;
}