后缀数组学习笔记

后缀数组又名 SA,它是一种十分实用的字符串处理工具,在很多地方能替代后缀自动机。

后缀数组的作用是O(nlogn)O(n\log n) 时间复杂度内求出一个字符串 SS 的所有后缀的排名

为了下文表述方便,我们先规定几个记号:

bib_i 为字符串 SSii 开始的后缀,str[l,r]str[l,r] 表示 strstr 的第 ll 到第 rr 这段区间里的字符组成的字符串,saisa_i 为排名为 ii 的后缀的起始位置,rkirk_ibib_i 的排名。很明显,sasarkrk 互为逆操作。

我们可以先考虑一个弱智问题:

对所有后缀的第一个字符排序。

很明显,这个东西可以用桶排序来解决,但是注意会有并列的情况

对于并列的情况,为了方便以后的处理,我们rkrk 相同但 sasa 不相同

代码如下:

struct node
{
	int x,pos;
}tmp[1000005],tmp2[1000005];

int n;
char s[1000005];
int sa[1000005],rk[1000005];
int tot[1000005];

inline void fastsort(int w) // w 为值域 
{
	for(int i=0;i<=w;i++) tot[i]=0;
	// 桶排序,tot[i] 表示权值 <= i 的元素个数 
	for(int i=1;i<=n;i++) tot[tmp[i].x]++;
	for(int i=1;i<=w;i++) tot[i]+=tot[i-1];
	// 获得排序后的数组,并复制到 tmp
	// 对于下面这个 for 的解释:
	// 由于 x 可能相同,所以权值为 tmp[i].x 的点可能有多个
	// 那么它们的排名区间是 [  tot[tmp[i].x-1]+1 , tot[tmp[i].x]  ] 
	// 由于排序最好是稳定的(即相同元素不改变相对位置),所以我们要从后往前跑循环,即从后往前插入到排名区间里 
	for(int i=n;i>=1;i--) tmp2[tot[tmp[i].x]--]=tmp[i];
	for(int i=1;i<=n;i++) tmp[i]=tmp2[i]; // 复制 
}

inline void sasort() // 弱智问题的解法 
{
	for(int i=1;i<=n;i++) tmp[i]=(node){(int)s[i],i}; // 注意要记录位置 
	fastsort(256);
	for(int i=1;i<=n;i++) sa[i]=tmp[i].pos; // 不考虑并列的话排名为 i 的后缀就是 b[tmp[i].pos]
	for(int i=1;i<=n;i++) rk[tmp[i].pos]=rk[tmp[i-1].pos]+(tmp[i].x!=tmp[i-1].x); // 注意只有元素不一样排名才增加 
}

解决完这个弱智问题后,我们来看一个进阶版的问题:

对所有后缀的前两个字符排序,即对所有关键字 <Si,Si+1><S_i,S_{i+1}> 排序。

这个问题使用桶排序有点难解决,但还是可做的。

首先对第二关键字排序,求出 tpitp_i 表示在第二关键字中排名为 ii 的元素的位置,并列则按位置排序

然后求出第一关键字 x\le x 的元素个数 totxtot_x

通过弱智问题代码注释中的结论,我们知道第一关键字为 xx 的排名区间为:

[totx1+1,totx][tot_{x-1}+1,tot_x]

又因为第二关键字中排名为 ii 的元素的位置为 tpitp_i,所以我们可以从后往前遍历 tptp 数组,把 tpitp_i 从后往前依次加进元素 ii 对应的排名区间内,这样不但保证排序正确,还能保证这个排序是稳定的

代码如下:

struct node
{
	int x,y,pos;
}tmp[1000005],tmp2[1000005];

int n;
char s[1000005];
int sa[1000005],rk[1000005];
int tp[1000005],tot[1000005];

inline void fastsort(int w)
{ 
	for(int i=0;i<=w;i++) tot[i]=0;
	// 对第二关键字排序,求出 tp[i] 表示第二关键字中排名第 i 的关键字的位置(相当于第二关键字的 sa) 
	for(int i=1;i<=n;i++) tot[tmp[i].y]++;
	for(int i=1;i<=w;i++) tot[i]+=tot[i-1];
	for(int i=n;i>=1;i--) tp[tot[tmp[i].y]--]=i;
	for(int i=0;i<=w;i++) tot[i]=0;
	// 对第一关键字排序 
	for(int i=1;i<=n;i++) tot[tmp[i].x]++;
	for(int i=1;i<=w;i++) tot[i]+=tot[i-1];
	// 结合起来 
	for(int i=n;i>=1;i--) tmp2[tot[tmp[tp[i]].x]--]=tmp[tp[i]];
	for(int i=1;i<=n;i++) tmp[i]=tmp2[i];
}

inline void sasort() // 进阶问题的解法
{
	for(int i=1;i<=n-1;i++) tmp[i]=(node){(int)s[i],(int)s[i+1],i}; // 只有 n-1 个元素有两个关键字 
	tmp[n]=(node){(int)s[n],0,i}; // 没有第二关键字,那么令它为 0 
	fastsort(256);
	for(int i=1;i<=n;i++) sa[i]=tmp[i].pos;
	for(int i=1;i<=n;i++) rk[tmp[i].pos]=rk[tmp[i-1].pos]+(tmp[i].x!=tmp[i-1].x||tmp[i].y!=tmp[i-1].y);
}

我们再来考虑一个问题:

对所有后缀的前四个字符排序,即对所有关键字 <Si,Si+1,Si+2,Si+3><S_i,S_{i+1},S_{i+2},S_{i+3}> 排序。

对于这个问题,我们并不需要重新写一个排序,因为那样太麻烦了。我们只需要先对 <Si,Si+1><S_i,S_{i+1}> 排序,再对 <rki,rki+2><rk_i,rk_{i+2}> 排序即可。因为和 bi[3,4]b_i[3,4] 最相似的长度为 22 的前缀是 bi+2[1,2]b_{i+2}[1,2],我们就可以把 bi[1,2]b_i[1,2]bi+2[3,4]b_{i+2}[3,4] 拼接起来排序。

所以,我们可以使用倍增来排序所有后缀

模板题代码如下:

// P3809
#include <iostream>
#include <algorithm>
#include <cstdio>
#include <cstring>

using namespace std;

struct node
{
	int x,y,pos;
}tmp[1000005],tmp2[1000005];

int n;
char s[1000005];
int sa[1000005],rk[1000005];
int tp[1000005],tot[1000005];

inline void fastsort(int w)
{
	for(int i=0;i<=w;i++) tot[i]=0;
	// 对第二关键字排序,求出 tp[i] 表示第二关键字中排名第 i 的关键字的位置(相当于第二关键字的 sa) 
	for(int i=1;i<=n;i++) tot[tmp[i].y]++;
	for(int i=1;i<=w;i++) tot[i]+=tot[i-1];
	for(int i=n;i>=1;i--) tp[tot[tmp[i].y]--]=i;
	for(int i=0;i<=w;i++) tot[i]=0;
	// 对第一关键字排序 
	for(int i=1;i<=n;i++) tot[tmp[i].x]++;
	for(int i=1;i<=w;i++) tot[i]+=tot[i-1];
	// 结合起来 
	for(int i=n;i>=1;i--) tmp2[tot[tmp[tp[i]].x]--]=tmp[tp[i]];
	for(int i=1;i<=n;i++) tmp[i]=tmp2[i];
}

inline void sasort()
{
	// 对第一个字符排序 
	for(int i=1;i<=n;i++) tmp[i]=(node){(int)s[i],0,i};
	fastsort(256);
	for(int i=1;i<=n;i++) sa[i]=tmp[i].pos;
	for(int i=1;i<=n;i++) rk[tmp[i].pos]=rk[tmp[i-1].pos]+(tmp[i].x!=tmp[i-1].x||tmp[i].y!=tmp[i-1].y);
	// 倍增 
	for(int p=1;p<=n;p<<=1)
	{
		for(int i=1;i<=n-p;i++) tmp[i]=(node){rk[i],rk[i+p],i};
		// b[i] 已经根据前 p 个字符排过序了 
		// 所以和 b[i] 的第 p+1 ~ p+p 个字符最相似的是 b[rk[i+p]] 的前 p 个字符 
		for(int i=n-p+1;i<=n;i++) tmp[i]=(node){rk[i],0,i}; // 没有第二关键字,那么设为 0 
		fastsort(rk[sa[n]]); // 排序,注意值域是 rk[sa[n]] 即最大的 rk 值 
		for(int i=1;i<=n;i++) sa[i]=tmp[i].pos;
		for(int i=1;i<=n;i++) rk[tmp[i].pos]=rk[tmp[i-1].pos]+(tmp[i].x!=tmp[i-1].x||tmp[i].y!=tmp[i-1].y);
		if(rk[sa[n]]>=n) break;
	}
}

int main()
{
	scanf("%s",s+1);
	n=strlen(s+1);
	sasort();
	for(int i=1;i<=n;i++)
	{
		printf("%d ",sa[i]);
	}
	printf("\n");
	return 0;
}

接下来我们考虑一个很经典的问题:(P2408 不同子串个数

求出某个字符串的不同子串个数。

我们bsaib_{sa_i}bsai1b_{sa_i-1} 的最长公共前缀长度为 heightiheight_ibib_{i}bsarki1b_{sa_{rk_i-1}} 的最长公共前缀长度为 hih_i。很明显 hi=heightrkih_i=height_{rk_i}

那么很显然可以用暴力,不断让 hih_i 增加直到 hih_i 大于两个后缀长度中最小的那个或者两个后缀的第 hi+1h_i+1 个字符不同

但是有个很巧妙的柿子:

hihi11h_i\ge h_{i-1}-1

证明如下:(转载自这里

最后 bib_i 对答案的贡献即为 (ni+1)hi(n-i+1)-h_i

题目代码如下:

#include <iostream>
#include <cstdio>
#include <cstring>

using namespace std;

struct node
{
	int x,y,pos;
}tmp[100005],tmp2[100005];

int n;
char s[100005];
int tot[100005],tp[100005];
int sa[100005],rk[100005];
int h[100005];

inline void fastsort(int w)
{
	for(int i=0;i<=w;i++) tot[i]=0;
	for(int i=1;i<=n;i++) tot[tmp[i].y]++;
	for(int i=1;i<=w;i++) tot[i]+=tot[i-1];
	for(int i=n;i>=1;i--) tp[tot[tmp[i].y]--]=i;
	for(int i=0;i<=w;i++) tot[i]=0;
	for(int i=1;i<=n;i++) tot[tmp[i].x]++;
	for(int i=1;i<=w;i++) tot[i]+=tot[i-1];
	for(int i=n;i>=1;i--) tmp2[tot[tmp[tp[i]].x]--]=tmp[tp[i]];
	for(int i=1;i<=n;i++) tmp[i]=tmp2[i];
}

inline void sasort()
{
	for(int i=1;i<=n;i++) tmp[i]=(node){(int)s[i],0,i};
	fastsort(256);
	for(int i=1;i<=n;i++) sa[i]=tmp[i].pos;
	for(int i=1;i<=n;i++) rk[tmp[i].pos]=rk[tmp[i-1].pos]+(tmp[i].x!=tmp[i-1].x||tmp[i].y!=tmp[i-1].y);
	for(int p=1;p<=n;p<<=1)
	{
		for(int i=1;i<=n-p;i++) tmp[i]=(node){rk[i],rk[i+p],i};
		for(int i=n-p+1;i<=n;i++) tmp[i]=(node){rk[i],0,i};
		fastsort(n);
		for(int i=1;i<=n;i++) sa[i]=tmp[i].pos;
		for(int i=1;i<=n;i++) rk[tmp[i].pos]=rk[tmp[i-1].pos]+(tmp[i].x!=tmp[i-1].x||tmp[i].y!=tmp[i-1].y);
		if(rk[sa[n]]>=n) break;
	}
}

inline void geth()
{
	for(int i=1;i<=n;i++)
	{
		int k=max(h[i-1]-1,0);
		int pos=sa[rk[i]-1];
		while(i+k<=n&&pos+k<=n&&s[i+k]==s[pos+k])
		{
			k++;
		}
		h[i]=k;
	}
}

int main()
{
	scanf("%d",&n);
	scanf("%s",s+1);
	sasort();
	geth();
	long long ans=0;
	for(int i=1;i<=n;i++)
	{
		ans+=(n-i+1)-h[i];
	}
	printf("%lld\n",ans);
	return 0;
}

练习题目

P3763 [TJOI2017]DNA

P2463 [SDOI2008] Sandy 的卡片

P2852 [USACO06DEC]Milk Patterns G

P3181 [HAOI2016]找相同字符