CF1503E 2-Coloring 做题记录

输入正整数 n,mn,m,要求把一个 n×mn \times m 的棋盘染成蓝色和黄色,且每行恰好有一段蓝色的格子,每列恰好有一段黄色的格子。输出总方案数对 998244353998244353 取模的结果。

n,m2021n,m\le2021

CF *3100 虚高。

显然合法的棋盘是长这样的:

抽象一点:

大概就是上面一个山坡下面一个山坡,并且两个山坡没有接触。

先考虑怎么数山坡的方案数,显然可以把山坡劈成单调不降的两段,设 dpi,jdp_{i,j} 表示 ii 根柱子,最高的柱子(第 ii 根)高 jj 的方案数,显然有 dp0,0=1dp_{0,0}=1,转移为:

dpi,j=k=0jdpi1,kdp_{i,j}=\sum\limits_{k=0}^{j}dp_{i-1,k}

接下来考虑拼接形成山坡,由于上面的山坡和下面的山坡不能有接触,但是每一行都必须有某一个山坡经过,所以两个山坡最高的柱子之间只有两种情况:

(最高的柱子交错)

和:

(最高的柱子高度和为 nn

交错的情况很简单,显然两根最高的柱子一定是相邻的,那么把网格劈成两半,算出 dp2i,jdp2_{i,j} 表示上下都有 ii 根柱子,下面最高的那根柱子长度为 jj,上面最高的那根柱子长度 nj1\le n-j-1 的方案数,显然有 dp2i,j=dpi,j×k=0nj1dpi,kdp2_{i,j}=dp_{i,j}\times\sum\limits_{k=0}^{n-j-1}dp_{i,k}。然后枚举左边的那一半有多长和上面最高的柱子的高度,和右边拼起来即可:

ans1=i=1m1j=2n1dp2i,j×k=nj+1n1dp2mi,kans1=\sum\limits_{i=1}^{m-1}\sum\limits_{j=2}^{n-1}dp2_{i,j}\times\sum\limits_{k=n-j+1}^{n-1}dp2_{m-i,k}

高度和为 nn 的情况有点麻烦,不难发现上面的最高柱子区间和下面的最高柱子区间一定不相交,那么仍旧把网格劈成两半,求出 dp3i,jdp3_{i,j} 表示一共有 ii 个柱子可以是最高的,最高的柱子高度为 jj 时山坡的方案,显然有 dp3i,j=dpi,j×k=0j1dpmi,kdp3_{i,j}=dp_{i,j}\times\sum\limits_{k=0}^{j-1}dp_{m-i,k}。然后强制让上面最高柱子的区间右端点为 ii,枚举下面最高柱子的区间左端点 jj 即可:

ans2=i=1m1j=1n1dp3i,j×k=i+1mdp3mi+1,njans2=\sum\limits_{i=1}^{m-1}\sum\limits_{j=1}^{n-1}dp3_{i,j}\times\sum\limits_{k=i+1}^{m}dp3_{m-i+1,n-j}

答案即为 ans=2(ans1+ans2)ans=2(ans1+ans2) 因为上下可以反过来。

不难发现所有式子都可以用前缀和优化到 O(nm)O(nm),那么就做完了。

代码如下:

#include <iostream>
#include <cstdio>
#include <cstring>

using namespace std;

const int S=3005,p=998244353;

int n,m;
int dp[S][S],sm[S][S];
int sm2[S][S];

inline void add(int &x,int y)
{
	x+=y;
	if(x>=p) x-=p;
}

int main()
{
	scanf("%d%d",&n,&m);
	dp[0][0]=1;
	for(int i=0;i<=n;i++) sm[0][i]=1;
	for(int i=1;i<=m;i++)
	{
		for(int j=0;j<=n;j++) dp[i][j]=sm[i-1][j];
		for(int j=0;j<=n;j++)
		{
			if(j>0) sm[i][j]=sm[i][j-1];
			add(sm[i][j],dp[i][j]);
		}
	}
	for(int i=1;i<=m;i++)
	{
		for(int j=n-1;j>=0;j--)
		{
			sm2[i][j]=sm2[i][j+1];
			add(sm2[i][j],1ll*dp[i][j]*sm[i][n-j-1]%p);
		}
	}
	int ans=0;
	for(int i=1;i<=m-1;i++)
	{
		for(int j=2;j<=n-1;j++)
		{
			int pre=1ll*dp[i][j]*sm[i][n-j-1]%p*sm2[m-i][n-j+1]%p;
			add(ans,pre);
		}
	}
	memset(sm2,0,sizeof(sm2));
	for(int i=1;i<=m;i++)
	{
		for(int j=1;j<=n-1;j++)
		{
			int pre=1ll*dp[i][j]*sm[m-i][j-1]%p;
			sm2[i][j]=sm2[i-1][j];
			add(sm2[i][j],pre);
		}
	}
	for(int i=1;i<=m;i++)
	{
		for(int j=1;j<=n-1;j++)
		{
			int pre=1ll*dp[i][j]*sm[m-i][j-1]%p*sm2[m-i][n-j]%p;
			add(ans,pre);
		}
	}
	ans=2ll*ans%p;
	printf("%d\n",ans);
	return 0;
}