求所有由 种不同字符组成,长度为 的字符串有多少种不同的后缀数组,其中第 种字符至多出现 次。对 取模。
。
理解能力太差,直觉太弱。
考虑不等式链:
所有满足不等式链的字符串的后缀数组都是 ,而判断一个不等式链是否合法(有一个满足输入条件的字符串满足该不等式链)是简单的,仅需贪心:
- 对于 $\le $,尽量让其变为 ,实在不行再让其变为 ;
具体的,设 将 分成了 段,第 段长度为 ,贪心:
- 从 开始填字符,假设当前填到了字符 ,是第 段,则:
- 若 ,则 ;
- 否则 ;
我们现在构建了 后缀数组 到 不等式链 的 映射,并且可以判断一个不等式链是否合法。
但是不同的后缀数组有可能映射到相同的不等式链( 和 ),所以这是映射而不是单射。
那么问题就变为计算有多少个后缀数组可以映射到某一个不等式链。
但是直接去对一个不等式链计数后缀数组似乎是困难的,主要是后缀数组到不等式链的映射太奇怪了。
那么考虑抛弃题目限制,对于一个不等式链 ,构建一个满足其限制的 字符集大小最小的 后缀排序后的字符串(应用 对应的变换之后的)。显然是 个 , 个 这样依次拼接形成的字符串,不妨记其为 ,不难发现这是一个 单射。
那么对于满足后缀排序后等于 的不同的原字符串,它们对应的后缀数组一定不同。
反证法,若两个不同字符串对应的后缀数组相同,则对应的不等式链也相同,故 也相同。而又由于后缀数组相同,故 的逆变换相同,原来的字符串相同,矛盾.
所以 对应的后缀数组的个数 不多于 对应的原字符串的个数 。而显然有 。
考虑 算多了什么,即计算这些原字符串对应的后缀数组有多少个对应的不等式链不是 。那么有可能某些 变为了 $\le $ 或者某些 $\le $ 变为了 。注意到 $\le $ 变为 是不可能的,因为这会使得 不再合法,故仅有可能是某些 变为了 $\le $。
考虑容斥,钦定若干个 变为 $\le $。其实就是合并了一些相邻的段,而注意到 变为 的极大集合为 的情况会被它所有子集算到,故若钦定 个则容斥系数为 。
现在我们需要将原来的限制(有关每种字符数量的)加上。
那么考虑边 dp 边容斥,具体的,贪心判断不等式链合法性的过程最终会得到一个字符串,这个字符串和合法的不等式链是双射的,所以不妨对这个字符串 dp。
设 表示填了 这些字符,目前字符串长度为 ,末尾(未处理贡献的)段长度为 。那么有 ,答案为 。
考虑转移:
- 填完 还不够填满这一段:;
- 填了 个 后分段了(出现了 ):;
- 填了 个 后碰到了一个钦定从 变来的 $\le f_{i,j,k}\times (-1)\to f_{i+1,j+x,k+x}$;
注意到由于转移中 ,故第三个转移和第一个转移抵消了:
- ,;
- ,;
前缀和优化即可,时间复杂度 。
具体的:
记得特判 的字符。
代码如下:
#include <iostream>
#include <algorithm>
#include <cstdio>
#include <cstring>
using namespace std;
const int S=505;
#define p 1000000007
int fra[S],inv[S];
int n,m,c[S];
int sm[S][S],f[2][S][S];
inline void add(int &x,int y)
{
x+=y;
if(x>=p) x-=p;
}
inline int qpow(int x,int y)
{
int res=1;
for(;y>0;y>>=1,x=1ll*x*x%p) res=y&1?1ll*res*x%p:res;
return res;
}
int main()
{
fra[0]=1;
for(int i=1;i<=S-3;i++) fra[i]=1ll*fra[i-1]*i%p;
inv[S-3]=qpow(fra[S-3],p-2);
for(int i=S-3;i>=1;i--) inv[i-1]=1ll*inv[i]*i%p;
scanf("%d%d",&n,&m);
for(int i=1;i<=m;i++) scanf("%d",&c[i]);
int ans=0;
f[0][0][0]=1;
for(int i=1;i<=m;i++)
{
int u=i&1,v=u^1;
memset(f[u],0,sizeof(f[u]));
if(c[i]==0)
{
memcpy(f[u],f[v],sizeof(f[u]));
continue;
}
for(int j=0;j<=n;j++)
{
for(int k=0;k<=n;k++)
{
sm[j][k]=f[v][j][k];
if(j>0&&k>0) add(sm[j][k],sm[j-1][k-1]);
}
}
for(int j=1;j<=n;j++)
{
for(int x=1;x<=n&&x<=j;x++)
{
// x-c_i <= k <= x-1
int lb=max(x-c[i],0),rb=x-1;
int pre=sm[j-x+rb][rb];
if(lb>0) add(pre,p-sm[j-x+lb-1][lb-1]);
add(f[u][j][0],1ll*inv[x]*pre%p);
}
}
for(int j=1;j<=n;j++)
{
for(int k=1;k<=n;k++)
{
add(f[u][j][k],p-sm[j-1][k-1]);
if(j>=c[i]&&k>=c[i])
add(f[u][j][k],sm[j-c[i]][k-c[i]]);
}
}
add(ans,f[u][n][0]);
}
printf("%d\n",1ll*fra[n]*ans%p);
return 0;
}