CF1784F Minimums or Medians 做题记录

你有一个包含 12n1 \sim 2n2n2n 个整数的集合 SS。你必须执行恰好 kk1kn1061 \le k \le n \le 10^6)次操作,每个操作都是以下两种其中之一:

  • SS 中第 1,21, 2 个元素删去。

  • SS 中第 S2,S2+1\frac{|S|}2, \frac{|S|}2 + 1 个元素删去。(显然 S|S| 一直是偶数,所以 S2\frac{|S|}2 也一定是整数)

请你统计,通过这些操作可以获得多少个本质不同的最终集合 SS?答案对 998244353998244353 取模。

Jerry Wen 定理。

考虑什么样的序列有可能得到(找必要条件),不妨关心被删除的数。

考虑被删除的数的连续段,由于每次是删除相邻两个数,所以每个连续段的长度一定是偶数

不难发现,第一个连续段一定可以被操作一删掉,为了防止算重,不妨钦定第一个连续段是被操作一删掉的。而之后的连续段一定不能被操作一删掉,所以除了第一个之外的连续段都是被操作二删掉的

继续观察,发现第 ii 次操作时中位数中较大的那个一定是 n+in+i,所以被删除的最大的数不超过 n+kn+k

发现操作二删除连续段有两种方式:

  • ((()))((())) 型,即不进行操作一,只进行操作二;
  • ()()()()()() 型,即一二交替;

发现只有第二个连续段可能使用第一种方式删除,且只有这种方式能删去 n\le n 的数。所以若第二个连续段中第一个数为 pppnp\le n,则 [p,2n+1p][p,2n+1-p] 一定都在第二个连续段中

考虑对于所有满足上述加粗条件的连续段集合 SS,一定能构造出合法的操作序列:对于 SS 中的每个连续段的右半部分,若中位数中较大的一个在其中则执行操作二,否则执行操作一。

那么只要数这样的 SS 即可。

为了计数方便,设 f(n,m)f(n,m) 表示从 nn 个数中选出 2m2m 个数且这些数构成的所有连续段长度均为偶数的方案数。这相当于从 nmn-m 个数中选出 mm 个,每数再扩充成两个,所以 f(n,m)=(nmm)f(n,m)=\binom{n-m}{m}

枚举第一段的长度 2i2i

  • 首先有可能剩下的所有连续段都是一二交替地删除的,即第二段的第一个数 >n>n,这部分的方案数是 f(min(k,k2i+n1),ki)f(\min(k,k-2i+n-1),k-i)

  • 2i<n2i<n 则第二段的第一个数有可能 n\le n,那么需要枚举 pp 即第二段的第一个数,方案数是:

    p=2i+2nf(kn+p1,kin1+p)=p=2i+2nf(kn+p1,kin+p1)=p=2i+2n(ikin+p1)\begin{aligned} \sum\limits_{p=2i+2}^nf(k-n+p-1,k-i-n-1+p)&=\sum\limits_{p=2i+2}^nf(k-n+p-1,k-i-n+p-1)\\ &=\sum\limits_{p=2i+2}^n\binom{i}{k-i-n+p-1}\\ \end{aligned}

    发现相当于是求 j=LiRi(ij)\sum\limits_{j=L_i}^{R_i} \binom{i}{j},注意到 [Li,Ri][L_i,R_i][Li+1,Ri+1][L_{i+1},R_{i+1}] 两端指针的移动是 O(1)O(1) 的,并且若已知 sml=j=LiRi(i1j)sml=\sum\limits_{j=L_i}^{R_i} \binom{i-1}{j} 也可以快速求出 smp=j=LiRi(ij)smp=\sum\limits_{j=L_i}^{R_i} \binom{i}{j} 因为 smp=2sml+(i1Li1)(i1Ri)smp=2sml+\binom{i-1}{L_i-1}-\binom{i-1}{R_i},那么双指针维护即可;

时间复杂度 O(n)O(n),代码如下:

// Problem: Minimums or Medians
// Contest: Luogu
// URL: https://www.luogu.com.cn/problem/CF1784F
// Memory Limit: 500 MB
// Time Limit: 4000 ms
// 
// Powered by CP Editor (https://cpeditor.org)

#include <iostream>
#include <cstdio>

int mod;

// fastmod
struct mint
{
	int val;
	inline mint(){val=0;}
	operator int(){return val;}
	template<typename T>inline mint(T x){val=(x%mod+mod)%mod;}
	inline mint operator-(){return mod-val;}
	template<typename T>inline mint operator^(T b)
	{
		mint tmp=*this,res=1;
		for(;b>0;b>>=1,tmp*=tmp) res=b&1?res*tmp:res;
		return res; 
	}
	template<typename T>inline mint operator^=(T b){return *this=*this^b;}
	inline mint inv(){return this->operator^(mod-2);}
	inline mint operator+(mint b){return val+b.val-(val+b.val>=mod?mod:0);}
	inline mint operator-(mint b){return val-b.val+(val-b.val<0?mod:0);}
	inline mint operator*(mint b){return 1ll*val*b.val%mod;}
	inline mint operator/(mint b){return 1ll*val*b.inv().val%mod;}
	inline mint operator+=(mint b){return val=val+b.val-(val+b.val>=mod?mod:0);}
	inline mint operator-=(mint b){return val=val-b.val+(val-b.val<0?mod:0);}
	inline mint operator*=(mint b){return val=1ll*val*b.val%mod;}
	inline mint operator/=(mint b){return val=1ll*val*b.inv().val%mod;}
	template<typename T>inline mint operator+(T b){return this->operator+(mint(b));};
	template<typename T>inline mint operator-(T b){return this->operator-(mint(b));};
	template<typename T>inline mint operator*(T b){return this->operator*(mint(b));};
	template<typename T>inline mint operator/(T b){return this->operator/(mint(b));};
	template<typename T>inline mint operator+=(T b){return this->operator+=(mint(b));};
	template<typename T>inline mint operator-=(T b){return this->operator-=(mint(b));};
	template<typename T>inline mint operator*=(T b){return this->operator*=(mint(b));};
	template<typename T>inline mint operator/=(T b){return this->operator/=(mint(b));};
};

using namespace std;

const int S=10000005;

mint fra[S],inv[S];
int n,k;

inline mint C(int n,int m)
{
	if(n<0||m<0||n<m) return 0;
	return fra[n]*inv[n-m]*inv[m];
}

inline mint f(int n,int m)
{
	return C(n-m,m);
}

int main()
{
	mod=998244353;
	fra[0]=1;
	for(int i=1;i<=S-3;i++) fra[i]=fra[i-1]*i;
	inv[S-3]=fra[S-3].inv();
	for(int i=S-3;i>=1;i--) inv[i-1]=inv[i]*i;
	scanf("%d%d",&n,&k);
	mint ans=0;
	int lb=1,rb=0,pr=-1;
	mint sum=0;
	for(int i=0;i<=k;i++)
	{
		ans+=i==n?(mint)1:f(min(k,k-i*2+n-1),k-i);
		if(i*2<n)
		{
			int L=k-n+i+1,R=k-i-1;
			if(pr>=0) sum=sum*2+C(pr,lb-1)-C(pr,rb);
			pr++;
			while(lb>L) sum+=C(pr,--lb);
			while(rb<R) sum+=C(pr,++rb);
			while(lb<L) sum-=C(pr,lb++);
			while(rb>R) sum-=C(pr,rb--);
			ans+=sum;
		}
	}
	printf("%d\n",ans);
	return 0;
}