求所有由 种不同字符组成,长度为 的字符串有多少种不同的后缀数组,其中第 种字符至多出现 次。对 取模。
。
理解能力太差,直觉太弱。
考虑不等式链:
所有满足不等式链的字符串的后缀数组都是 ,而判断一个不等式链是否合法(有一个满足输入条件的字符串满足该不等式链)是简单的,仅需贪心:
- 对于 $\le $,尽量让其变为 ,实在不行再让其变为 ;
具体的,设 将 分成了 段,第 段长度为 ,贪心:
- 从  开始填字符,假设当前填到了字符 ,是第  段,则:
- 若 ,则 ;
- 否则 ;
 
我们现在构建了 后缀数组 到 不等式链 的 映射,并且可以判断一个不等式链是否合法。
但是不同的后缀数组有可能映射到相同的不等式链( 和 ),所以这是映射而不是单射。
那么问题就变为计算有多少个后缀数组可以映射到某一个不等式链。
但是直接去对一个不等式链计数后缀数组似乎是困难的,主要是后缀数组到不等式链的映射太奇怪了。
那么考虑抛弃题目限制,对于一个不等式链 ,构建一个满足其限制的 字符集大小最小的 后缀排序后的字符串(应用 对应的变换之后的)。显然是 个 , 个 这样依次拼接形成的字符串,不妨记其为 ,不难发现这是一个 单射。
那么对于满足后缀排序后等于 的不同的原字符串,它们对应的后缀数组一定不同。
反证法,若两个不同字符串对应的后缀数组相同,则对应的不等式链也相同,故 也相同。而又由于后缀数组相同,故 的逆变换相同,原来的字符串相同,矛盾.
所以 对应的后缀数组的个数 不多于 对应的原字符串的个数 。而显然有 。
考虑 算多了什么,即计算这些原字符串对应的后缀数组有多少个对应的不等式链不是 。那么有可能某些 变为了 $\le $ 或者某些 $\le $ 变为了 。注意到 $\le $ 变为 是不可能的,因为这会使得 不再合法,故仅有可能是某些 变为了 $\le $。
考虑容斥,钦定若干个 变为 $\le $。其实就是合并了一些相邻的段,而注意到 变为 的极大集合为 的情况会被它所有子集算到,故若钦定 个则容斥系数为 。
现在我们需要将原来的限制(有关每种字符数量的)加上。
那么考虑边 dp 边容斥,具体的,贪心判断不等式链合法性的过程最终会得到一个字符串,这个字符串和合法的不等式链是双射的,所以不妨对这个字符串 dp。
设 表示填了 这些字符,目前字符串长度为 ,末尾(未处理贡献的)段长度为 。那么有 ,答案为 。
考虑转移:
- 填完 还不够填满这一段:;
- 填了 个 后分段了(出现了 ):;
- 填了 个 后碰到了一个钦定从 变来的 $\le f_{i,j,k}\times (-1)\to f_{i+1,j+x,k+x}$;
注意到由于转移中 ,故第三个转移和第一个转移抵消了:
- ,;
- ,;
前缀和优化即可,时间复杂度 。
具体的:
记得特判 的字符。
代码如下:
#include <iostream>
#include <algorithm>
#include <cstdio>
#include <cstring>
using namespace std;
const int S=505;
#define p 1000000007
int fra[S],inv[S];
int n,m,c[S];
int sm[S][S],f[2][S][S];
inline void add(int &x,int y)
{
	x+=y;
	if(x>=p) x-=p;
}
inline int qpow(int x,int y)
{
	int res=1;
	for(;y>0;y>>=1,x=1ll*x*x%p) res=y&1?1ll*res*x%p:res;
	return res;
}
int main()
{
	fra[0]=1;
	for(int i=1;i<=S-3;i++) fra[i]=1ll*fra[i-1]*i%p;
	inv[S-3]=qpow(fra[S-3],p-2);
	for(int i=S-3;i>=1;i--) inv[i-1]=1ll*inv[i]*i%p;
	scanf("%d%d",&n,&m);
	for(int i=1;i<=m;i++) scanf("%d",&c[i]);
	int ans=0;
	f[0][0][0]=1;
	for(int i=1;i<=m;i++)
	{
		int u=i&1,v=u^1;
		memset(f[u],0,sizeof(f[u]));
		if(c[i]==0)
		{
			memcpy(f[u],f[v],sizeof(f[u]));
			continue;
		}
		for(int j=0;j<=n;j++)
		{
			for(int k=0;k<=n;k++)
			{
				sm[j][k]=f[v][j][k];
				if(j>0&&k>0) add(sm[j][k],sm[j-1][k-1]);
			}
		}
		for(int j=1;j<=n;j++)
		{
			for(int x=1;x<=n&&x<=j;x++)
			{
				// x-c_i <= k <= x-1
				int lb=max(x-c[i],0),rb=x-1;
				int pre=sm[j-x+rb][rb];
				if(lb>0) add(pre,p-sm[j-x+lb-1][lb-1]);
				add(f[u][j][0],1ll*inv[x]*pre%p);
			}
		}
		for(int j=1;j<=n;j++)
		{
			for(int k=1;k<=n;k++)
			{
				add(f[u][j][k],p-sm[j-1][k-1]);
				if(j>=c[i]&&k>=c[i])
					add(f[u][j][k],sm[j-c[i]][k-c[i]]);
			}
		}
		add(ans,f[u][n][0]);
	}
	printf("%d\n",1ll*fra[n]*ans%p);
	return 0;
}
