Trie 学习笔记

Trie,即字典树,也就是一颗功能上很像字典的树。

Trie 经常用于维护一些关于字符串的东西,例如给定一些字符串,询问给定的字符串里有没有 ss。又例如给定一些字符串,询问给定的字符串里有多少个前缀为 ss

Trie 的主要思路是以空间换时间,每一条边上都有一个字符类型的权值,根节点到节点 xx 的路径上所有字符拼接起来便是 xx 所代表的的字符串。

注意 Trie 的节点里还要保存它所表示的字符串是不是给出的字符串。

经典例题

P2580 于是他错误的点名开始了

完整代码:

#include <iostream>
#include <cstdio>

using namespace std;

int n,m;
char s[105];
int sons[500005][26]; // 大小是 50*10000 
int cnt;
bool vis[500005],endd[500005];

void ins(int u,string str)
{
	if(str.empty()) // 搞完了,标记一下这个点代表的字符串是给定的
	{
		endd[u]=true;
		return;
	}
	if(sons[u][str[0]-'a']==0)
	{
		sons[u][str[0]-'a']=++cnt; // 没有这个儿子,新开一个点
	}
	ins(sons[u][str[0]-'a'],str.substr(1)); // 递归下去
}

int que(int u,string str)
{
	if(str.empty()) // 搞完了
	{
		if(endd[u]) // 是给定的字符串
		{
			if(vis[u]) // 访问过
			{
				return 2;
			}
			vis[u]=true;
			return 1;
		}
		return 0;
	}
	if(sons[u][str[0]-'a']==0) // 没有节点可以代表 str
	{
		return 0;
	}
	return que(sons[u][str[0]-'a'],str.substr(1)); // 递归下去
}

int main()
{
	scanf("%d",&n);
	cnt=1;
	for(int i=1;i<=n;i++)
	{
		scanf("%s",s);
		ins(1,s);
	}
	scanf("%d",&m);
	while(m--)
	{
		scanf("%s",s);
		int res=que(1,s);
		puts(res==1?"OK":(res==2?"REPEAT":"WRONG"));
	}
	return 0;
}

LOJ2742 销售基因链(JOI Open 2016 T2 「RNA 鎖の販売 / Selling RNA Strands」)

JOI 原题

首先对所有字符串和其翻转串建出 Trie 树,这样问题就变成了求同时在两棵树的两个子树中的字符串的个数。

考虑 dfs 序,把子树映射为区间之后就变成了二维数点问题,可以直接离线下来做。

代码如下:

#include <iostream>
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <vector>

using namespace std;

const int S=5000005;

struct node
{
	int x,y;
}a[S],que[S];

struct node2
{
	int x,y,id,tpe;
};

int n,q;
char str[S];
int cnt1=1,son1[S][4];
int cnt2=1,son2[S][4];
vector<int> idx1[S],idx2[S];
int tt1,dfn1[S];
int tt2,dfn2[S];
int lx[S],rx[S],ly[S],ry[S];
int m;
node2 pts[S];
int c[S],res[S];

inline int id(char x)
{
	return x=='A'?0:(x=='U'?1:(x=='G'?2:3));
}

void ins1(int idd)
{
	int u=1,n=strlen(str+1);
	for(int i=1;i<=n;i++)
	{
		int x=id(str[i]);
		if(son1[u][x]==0) son1[u][x]=++cnt1;
		u=son1[u][x];
	}
	idx1[u].push_back(idd);
}

void ins2(int idd)
{
	int u=1,n=strlen(str+1);
	for(int i=1;i<=n;i++)
	{
		int x=id(str[i]);
		if(son2[u][x]==0) son2[u][x]=++cnt2;
		u=son2[u][x];
	}
	idx2[u].push_back(idd);
}

void dfs1(int u)
{
	lx[u]=dfn1[u]=++tt1;
	for(int i:idx1[u]) a[i].x=tt1;
	for(int i=0;i<4;i++) if(son1[u][i]!=0) dfs1(son1[u][i]); 
	rx[u]=tt1;
}

void dfs2(int u)
{
	ly[u]=dfn2[u]=++tt2;
	for(int i:idx2[u]) a[i].y=tt2;
	for(int i=0;i<4;i++) if(son2[u][i]!=0) dfs2(son2[u][i]);
	ry[u]=tt2;
}

int got1()
{
	int u=1,n=strlen(str+1);
	for(int i=1;i<=n;i++)
	{
		int x=id(str[i]);
		u=son1[u][x];
	}
	return u;
}

int got2()
{
	int u=1,n=strlen(str+1);
	for(int i=1;i<=n;i++)
	{
		int x=id(str[i]);
		u=son2[u][x];
	}
	return u;
}

void addd(int pos,int val)
{
	for(int i=pos;i<=S-3;i+=i&-i) c[i]+=val;
}

int quee(int pos)
{
	int res=0;
	for(int i=pos;i>=1;i-=i&-i) res+=c[i];
	return res;
}

int main()
{
	scanf("%d%d",&n,&q);
	for(int i=1;i<=n;i++)
	{
		scanf("%s",str+1);
		ins1(i);
		int len=strlen(str+1);
		for(int j=1;j<=len/2;j++) swap(str[j],str[len-j+1]);
		ins2(i);
	}
	dfs1(1),dfs2(1);
	for(int i=1;i<=q;i++)
	{
		scanf("%s",str+1);
		que[i].x=got1();
		scanf("%s",str+1);
		int len=strlen(str+1);
		for(int j=1;j<=len/2;j++) swap(str[j],str[len-j+1]);
		que[i].y=got2();
	}
	for(int i=1;i<=q;i++)
	{
//		printf("[%d %d] [%d %d]\n",lx[que[i].x],rx[que[i].x],ly[que[i].y],ry[que[i].y]);
		if(que[i].x==0||que[i].y==0) continue;
		pts[++m]=(node2){rx[que[i].x],ry[que[i].y],i,1};
		pts[++m]=(node2){lx[que[i].x]-1,ry[que[i].y],i,-1};
		pts[++m]=(node2){rx[que[i].x],ly[que[i].y]-1,i,-1};
		pts[++m]=(node2){lx[que[i].x]-1,ly[que[i].y]-1,i,1};
	}
	sort(a+1,a+n+1,[&](node x,node y){return x.x<y.x;});
	sort(pts+1,pts+m+1,[&](node2 x,node2 y){return x.x<y.x;});
//	for(int i=1;i<=n;i++) printf("%d %d\n",a[i].x,a[i].y);
	for(int i=1,j=1;i<=m;i++)
	{
		while(j<=n&&a[j].x<=pts[i].x) addd(a[j++].y,1);
		int pre=quee(pts[i].y);
		res[pts[i].id]+=pre*pts[i].tpe;
	}
	for(int i=1;i<=q;i++) printf("%d\n",res[i]);
	return 0;
}