fhq-Treap 学习笔记

顾名思义,fhq-Treap 就是由 fhq 大佬改进的一种 Treap。它的核心思想是不使用旋转操作,而是使用分裂和合并操作来同时维护二叉搜索树的特性和二叉堆的特性

首先是分裂操作,顾名思义,就是把一棵 fhq-Treap 分裂成两颗 fhq-Treap。分裂有两种方式,一种是按照权值分裂,还有一种是按照大小分裂。而前者更常用,后者则通常是用来维护区间的。

按照权值分裂的代码如下:

void split(int u,int val,int &x,int &y) // 按权值分裂以 u 为根的树(x 子树上的所有权值都 < val,y 子树上的所有权值都 >= val) 
{
	if(u==0) // 如果没得分裂了 
	{
		x=y=0; // 赋值为 0 
		return;
	}
	if(tree[u].val<val) // 如果 u 的权值比 val 小,那么 u 属于 x 的子树 
	{
		x=u; // 赋值 
		split(tree[u].r,val,tree[u].r,y); // 递归分裂当前 u 的右儿子(u 的左儿子都比 val 小,都属于 x 的子树) 
	}
	else // 否则 u 属于 y 的子树 
	{
		y=u; // 赋值 
		split(tree[u].l,val,x,tree[u].l); // 递归分裂当前 u 的左儿子(u 的右儿子都 >= val,都属于 y 的子树) 
	}
	upd(u); // 不要忘记更新节点信息 
}

按大小分裂的代码如下:

void split(int u,int val,int &x,int &y)
{
	if(u==0)
	{
		x=y=0;
		return;
	}
	if(tree[tree[u].lson].sum<val)
	{
		x=u;
		split(tree[u].rson,val-tree[tree[u].lson].sum-1,tree[u].rson,y);
	}
	else
	{
		y=u;
		split(tree[u].lson,val,x,tree[u].lson);
	}
	upd(u);
}

然后是合并操作,即把两棵树合并成一棵树。合并的时候注意要同时满足二叉搜索树的特性和二叉堆的的特性。

int merge(int x,int y) // 合并 x 的子树和 y 的子树,返回合并之后的根(x 的子树的权值必须保证都小于 y 的子树的权值) 
{
	if(x==0||y==0) return x+y; // 有一棵子树为空,那么返回另一棵子树的根 
	if(tree[x].w>tree[y].w) // 我们要保证 w 满足大根堆的特性,所以如果 x 的 w 大于 y 的 w,那么就要把 y 接到 x 下面,
	// 又因为 x 的子树的权值都小于 y 的子树的权值,所以 y 要接到 x 的右儿子,即让 x 的右儿子和 y 合并 
	{
		tree[x].r=merge(tree[x].r,y);
		upd(x); // 记得更新节点信息 
		return x;
	}
	else // 同理 
	{
		tree[y].l=merge(x,tree[y].l); 
		upd(y); // 记得更新节点信息 
		return y;
	}
}

解决了这两个最核心也最难的操作后,剩下的操作就简单了。

插入操作就相当于把整棵树扒开,然后再把要插入的节点放进去,最后再缝合起来。这种十分暴力的行为很好写,而且还不容易出错,代码如下:

inline void ins(int val) // 插入 
{
	int x,y;
	split(rt,val,x,y); // 扒开整棵树 
	int u=++cnt; // 新建一个节点 
	tree[u].val=val;
	tree[u].sum=1;
	tree[u].w=rand();
	rt=merge(x,merge(u,y)); // 放进去,缝合起来 
}

删除操作则相当于扒开整棵树,找到要删除的子树的根,合并它的左右儿子,再把整棵树缝合起来。同样很暴力,代码如下:

inline void del(int val) // 删除 
{
	int x,y,z;
	split(rt,val,x,y); // 先把整棵树分为权值都小于 val 的子树 x 和权值都大于等于 val 的子树 y 
	split(y,val+1,y,z); // 再把 y 的子树里权值大于 val 的子树分割出来,此时 y 子树内的权值都等于 val 
	if(y!=0) y=merge(tree[y].l,tree[y].r); // 如果存在权值为 val 的节点,即 y!=0,则合并 y 的左右儿子 
	rt=merge(x,merge(y,z)); // 缝合整棵树 
}

查询 val 的排名则相当于把扒开整棵树,然后记录下权值小于 val 的子树的大小 + 1,再缝合整棵树,最后返回之前记录下的答案。代码如下:

inline int getrk(int val) // 求 val 的排名 
{
	int x,y;
	split(rt,val,x,y); // 扒开整棵树 
	int res=tree[x].sum+1; // 记录下权值都小于 val 的子树的大小 + 1(答案) 
	rt=merge(x,y); // 缝合整棵树 
	return res; // 返回答案 
}

查询排名为 val 的数是唯一一个没有那么暴力的操作,它是直接在 fhq-Treap 那布满伤痕的身体上进行遍历来求答案的:

inline int getbyrk(int val) // 获取排名为 val 的数 
{
	int u=rt; // 从根开始遍历 
	while(1)
	{
		if(tree[tree[u].l].sum+1==val) break; // 如果左子树大小 + 1 == val,那么答案就是当前节点的值 
		else if(tree[tree[u].l].sum+1>val) u=tree[u].l; // 如果左子树大小 + 1 > val,那么答案在当前节点的左子树 
		else // 否则答案在当前右子树 
		{
			val-=tree[tree[u].l].sum+1; // 记得把 val 减去左子树大小 + 1 
			u=tree[u].r; 
		}
	}
	return tree[u].val; // 返回答案 
}

求前驱后继也相当于把整棵树扒开,然后再处理,最后缝合上

inline int getfrt(int val) // 求前驱 
{
	int x,y;
	split(rt,val,x,y); // 把整棵树扒开,前驱肯定在 x 的子树内 
	int u=x;
	while(1) // 由于前驱是 x 子树内最大的那个值,所以要一直往右儿子去 
	{
		if(tree[u].r!=0) u=tree[u].r;
		else break;
	}
	rt=merge(x,y); // 缝合好 
	return tree[u].val; // 返回答案 
}

inline int getnxt(int val) // 求后继 
{
	int x,y;
	split(rt,val+1,x,y); // 把整棵树扒开,后继肯定在 y 的子树内 
	int u=y;
	while(1) // 由于后继是 y 子树内最小的那个值,所以要一直往左儿子去 
	{
		if(tree[u].l!=0) u=tree[u].l;
		else break;
	}
	rt=merge(x,y); // 缝合好 
	return tree[u].val; // 返回答案 
}

最终模板题代码如下:

#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <ctime>

using namespace std;

struct node
{
	int val,w;
	int sum;
	int l,r;
}tree[100005];

int q;
int cnt,rt;

inline void upd(int u)
{
	tree[u].sum=tree[tree[u].l].sum+tree[tree[u].r].sum+1;
}

void split(int u,int val,int &x,int &y) // 按权值分裂以 u 为根的树(x 子树上的所有权值都 < val,y 子树上的所有权值都 >= val) 
{
	if(u==0) // 如果没得分裂了 
	{
		x=y=0; // 赋值为 0 
		return;
	}
	if(tree[u].val<val) // 如果 u 的权值比 val 小,那么 u 属于 x 的子树 
	{
		x=u; // 赋值 
		split(tree[u].r,val,tree[u].r,y); // 递归分裂当前 u 的右儿子(u 的左儿子都比 val 小,都属于 x 的子树) 
	}
	else // 否则 u 属于 y 的子树 
	{
		y=u; // 赋值 
		split(tree[u].l,val,x,tree[u].l); // 递归分裂当前 u 的左儿子(u 的右儿子都 >= val,都属于 y 的子树) 
	}
	upd(u); // 不要忘记更新节点信息 
}

int merge(int x,int y) // 合并 x 的子树和 y 的子树,返回合并之后的根(x 的子树的权值必须保证都小于 y 的子树的权值) 
{
	if(x==0||y==0) return x+y; // 有一棵子树为空,那么返回另一棵子树的根 
	if(tree[x].w>tree[y].w) // 我们要保证 w 满足大根堆的特性,所以如果 x 的 w 大于 y 的 w,那么就要把 y 接到 x 下面,
	// 又因为 x 的子树的权值都小于 y 的子树的权值,所以 y 要接到 x 的右儿子,即让 x 的右儿子和 y 合并 
	{
		tree[x].r=merge(tree[x].r,y);
		upd(x); // 记得更新节点信息 
		return x;
	}
	else // 同理 
	{
		tree[y].l=merge(x,tree[y].l); 
		upd(y); // 记得更新节点信息 
		return y;
	}
}

inline void ins(int val) // 插入 
{
	int x,y;
	split(rt,val,x,y); // 扒开整棵树 
	int u=++cnt; // 新建一个节点 
	tree[u].val=val;
	tree[u].sum=1;
	tree[u].w=rand();
	rt=merge(x,merge(u,y)); // 放进去,缝合起来 
}

inline void del(int val) // 删除 
{
	int x,y,z;
	split(rt,val,x,y); // 先把整棵树分为权值都小于 val 的子树 x 和权值都大于等于 val 的子树 y 
	split(y,val+1,y,z); // 再把 y 的子树里权值大于 val 的子树分割出来,此时 y 子树内的权值都等于 val 
	if(y!=0) y=merge(tree[y].l,tree[y].r); // 如果存在权值为 val 的节点,即 y!=0,则合并 y 的左右儿子 
	rt=merge(x,merge(y,z)); // 缝合整棵树 
}

inline int getrk(int val) // 求 val 的排名 
{
	int x,y;
	split(rt,val,x,y); // 扒开整棵树 
	int res=tree[x].sum+1; // 记录下权值都小于 val 的子树的大小 + 1(答案) 
	rt=merge(x,y); // 缝合整棵树 
	return res; // 返回答案 
}

inline int getbyrk(int val) // 获取排名为 val 的数 
{
	int u=rt; // 从根开始遍历 
	while(1)
	{
		if(tree[tree[u].l].sum+1==val) break; // 如果左子树大小 + 1 == val,那么答案就是当前节点的值 
		else if(tree[tree[u].l].sum+1>val) u=tree[u].l; // 如果左子树大小 + 1 > val,那么答案在当前节点的左子树 
		else // 否则答案在当前右子树 
		{
			val-=tree[tree[u].l].sum+1; // 记得把 val 减去左子树大小 + 1 
			u=tree[u].r; 
		}
	}
	return tree[u].val; // 返回答案 
}

inline int getfrt(int val) // 求前驱 
{
	int x,y;
	split(rt,val,x,y); // 把整棵树扒开,前驱肯定在 x 的子树内 
	int u=x;
	while(1) // 由于前驱是 x 子树内最大的那个值,所以要一直往右儿子去 
	{
		if(tree[u].r!=0) u=tree[u].r;
		else break;
	}
	rt=merge(x,y); // 缝合好 
	return tree[u].val; // 返回答案 
}

inline int getnxt(int val) // 求后继 
{
	int x,y;
	split(rt,val+1,x,y); // 把整棵树扒开,后继肯定在 y 的子树内 
	int u=y;
	while(1) // 由于后继是 y 子树内最小的那个值,所以要一直往左儿子去 
	{
		if(tree[u].l!=0) u=tree[u].l;
		else break;
	}
	rt=merge(x,y); // 缝合好 
	return tree[u].val; // 返回答案 
}

int main()
{
	srand(time(NULL));
	scanf("%d",&q);
	while(q--)
	{
		int opt,x;
		scanf("%d%d",&opt,&x);
		if(opt==1) ins(x);
		if(opt==2) del(x);
		if(opt==3) printf("%d\n",getrk(x));
		if(opt==4) printf("%d\n",getbyrk(x));
		if(opt==5) printf("%d\n",getfrt(x));
		if(opt==6) printf("%d\n",getnxt(x));
	}
	return 0;
}