给定正整数 ,求有多少个 的子集 ,满足对于任意 中的正整数 ,都存在一个或两个 的子集 满足 。对 取模。
。
相当于用 中的数做 01 背包后背包数组中 每个位置的值都是 或者 。
考虑从小到大加入 中的数,每次相当于将背包数组整体往后移位再对位相加。那么如果某一时刻出现了 个连续非 段,则它们之间的 就再也无法被覆盖到了。故每个时刻背包数组中非 的位置一定是一个前缀,且该前缀长度是 中所有数的和再 ( 的下标从 开始)。
那么设背包数组为 , 中数的和为 ,则一定有 。
考虑第一次让 中出现 的数 ,此前加入的数一定都是 的次幂。设 为最小的 满足 ,则此时 一定形如:
接下来每次操作新产生的 都来自于最后一段 和第一段 的叠合,直到 都非 ,此时显然只能再操作最多一次。问题是这一次操作中第一段 叠合的可能不再是最后一段 。
观察到 的性质很好(是回文串且砍掉最后一次操作叠合出的中间那段后,两边也是回文串),考虑建树,去掉开头的 个 后,每次操作叠合出的中间段(长度范围 )作为节点,向左右两边最后一次操作对应的节点连边。叶子节点对应 个 (第一次出现 的操作)。
这样我们就可以枚举是在哪个节点处爆 ,按层做背包即可统计答案。
复杂度:
- 枚举 带来一个 ;
- 一共有 个节点;
- 每个节点处需要做背包,复杂度是 的( 来自于调和级数);
- 所以总复杂度为 ,即 。
实现起来有很多细节,要考虑好每种情况的贡献应该算在什么节点上。
代码如下:
#include <iostream>
#include <cstdio>
#include <cstring>
using namespace std;
typedef long long ll;
const int S=1505,BS=10;
#define p 998244353
int n;
ll siz[BS+5];
int g[S],tmp[BS+5][S];
int ans;
inline void add(int &x,int y)
{
x+=y;
if(x>=p) x-=p;
}
void dfs(int a,int len,int hei,int cnt,ll sm)
{
if(sm>n+1) return;
if(hei>1)
{
{// left
memcpy(tmp[hei],g,sizeof(g));
if(cnt>1)
{
int ml=cnt-1;
for(int i=n+1;i>=0;i--)
{
for(int j=a;j<=a*2&&i+j*ml<=n+1;j++)
add(g[i+j*ml],g[i]);
g[i]=0;
}
}
dfs(a,len,hei-1,(cnt-1)*2+1,sm);
memcpy(g,tmp[hei],sizeof(g));
}
if(sm+a+siz[hei-1]<=n+1)
{// right
memcpy(tmp[hei],g,sizeof(g));
int ml=cnt;
for(int i=n+1;i>=0;i--)
{
for(int j=a;j<=a*2&&i+j*ml<=n+1;j++)
add(g[i+j*ml],g[i]);
g[i]=0;
}
dfs(a,len,hei-1,cnt*2,sm+a+siz[hei-1]);
memcpy(g,tmp[hei],sizeof(g));
}
}
if(a+sm+siz[hei-1]>n+1) return;
memcpy(tmp[hei],g,sizeof(g));
ll ml=(cnt-1)*2+1;
for(int i=hei-1;i>=1;i--,ml*=2)
{
int lb,rb;
if(i>1) lb=a,rb=a*2;
else lb=rb=len-a;
for(int j=n+1;j>=0;j--)
{
for(int k=lb;k<=rb&&j+k*ml<=n+1;k++)
add(g[j+k*ml],g[j]);
g[j]=0;
}
}
if(hei>1) // not leaf
{
for(int i=0;i<n+1-a;i++)
for(int x=a;x<=a*2&&a+i+x*(cnt-1)<n+1;x++)
{
int pos=a+i+x*(cnt-1);
int c1=x-a;
if(pos+x<n+1) continue;
if(cnt==1&&pos+c1>n+1) break;
add(ans,g[i]);
int r1=0;
if(cnt==1)
{
if(pos+x-c1<n+1)
r1=n+1-(pos+x-c1)-(x==a*2);
}
else
{
if(pos+c1>=n+1) r1=n+1-pos;
else if(x==a*2) r1=min(n+1-pos,a);
else if(pos+x-c1<n+1) r1=n+1-(pos+x-c1);
}
add(ans,1ll*r1*g[i]%p);
}
}
else
{
int ad=a+(len-a)*(cnt-1);
for(int i=0;i<n+1-ad;i++)
{
int pre=i+ad;
int x=len-a;
if(pre+x<n+1) continue;
add(ans,g[i]);
}
}
memcpy(g,tmp[hei],sizeof(g));
}
int main()
{
ios::sync_with_stdio(false),cin.tie(0),cout.tie(0);
cin>>n;
ans=1;
for(int a=1;a<=n+1;a++)
{
int len=1;
while(len<=a) len<<=1;
if(len==a*2) continue;
siz[1]=len-a;
for(int i=2;i<=BS;i++) siz[i]=siz[i-1]*2+a;
g[0]=1;
int lst=ans;
dfs(a,len,BS,1,0);
// for(int i=1;i<=BS-3;i++)
// {
// int lst=ans;
// dfs(a,len,i,1,0);
// if(ans>lst) printf("%d: %d\n",i,ans-lst);
// }
// printf(">> %d %d %d : %d\n",a,len-a,a,ans-lst);
}
cout<<ans<<'\n';
return 0;
}