Treap 学习笔记

Treap 是平衡树的一种,也是一棵笛卡尔树它是一种随机数据结构,相当于 wiw_i 是随机数的笛卡尔树

模板题

首先对于一个节点,我们不仅需要保存二叉搜索树的几个值,还要保存 wiw_i。节点结构体定义如下:

struct node
{
    int val,w; // 数值,权值 w 
    int cnt,sum; // 数值个数,子树大小 
    int l,r; // 左右儿子 
}tree[100005];

朴素的笛卡尔树并没有插入和删除的操作,而作为平衡树,Treap 必须支持这两种操作。但是插入和删除的过程中还要维护堆的特性,需要进行旋转操作,即二叉堆中的把某个儿子提上来。

由于旋转时还需要维护二叉搜索树的特性,所以旋转稍微有点复杂:

这样旋转不但可以做到把某个儿子提上来,还可以维护二叉搜索树的特性。

左旋右旋代码如下:

inline void upd(int u)
{
    tree[u].sum=tree[tree[u].l].sum+tree[tree[u].r].sum+tree[u].cnt;
}

inline void lrot(int &u) // 左旋 
{
	int t=tree[u].r; // 先存下 Y 的编号
    tree[u].r=tree[t].l; // 右儿子变成 B 
    tree[t].l=u; // Y 的左儿子变成 X 
    tree[t].sum=tree[u].sum; // Y 变成了之前 X 的子树的根 
    upd(u); // 更新 X 的子树大小 
    u=t; // 完成旋转 
}

inline void rrot(int &u) // 右旋,和左旋原理一样 
{
	int t=tree[u].l;
    tree[u].l=tree[t].r;
    tree[t].r=u;
    tree[t].sum=tree[u].sum;
    upd(u);
    u=t;
}

有了左旋右旋,插入和删除就不难实现了。

插入:

void ins(int &u,int val) // 插入 
{
    if(u==0) // 如果递归到了空节点,那么在这里插入 
    {
        u=++cnt;
        tree[u].val=val;
        tree[u].cnt=1;
        tree[u].sum=1;
        tree[u].w=rand(); // 随机化权值 w 
        return;
    }
    tree[u].sum++; // 子树大小 ++ 
    if(val==tree[u].val) tree[u].cnt++; // 如果找到了值为 val 的节点,那么该节点的数值个数 ++ 
    else if(val<tree[u].val)
    {
        ins(tree[u].l,val); // 往左子树插入 
        if(tree[tree[u].l].w<tree[u].w) rrot(u); // 维护二叉堆的性质 
    }
    else
    {
        ins(tree[u].r,val); // 往右子树插入 
        if(tree[tree[u].r].w<tree[u].w) lrot(u); // 维护二叉堆的性质 
    }
}

删除:(返回值为成不成功,即有没有找到值为 val 的节点)

bool del(int &u,int val) // 删除 
{
    if(u==0) return false; // 递归到了空节点,删除失败 
    if(val==tree[u].val) // 找到了 
    {
        if(tree[u].cnt>1) // 删除之后节点还存在 
        {
        	tree[u].sum--;
            tree[u].cnt--;
            return true; // 删除成功 
        }
        else
        {
        	if(tree[u].l==0||tree[u].r==0) // 只有一个儿子或者没有儿子,那么直接用儿子替换当前节点 
			{
				u=tree[u].l+tree[u].r;
				return true; // 删除成功 
			}
            else // 有两个儿子,那么我们可以把当前节点通过左旋右旋往下移动,直到可以直接删除 
            {
                if(tree[tree[u].l].w<tree[tree[u].r].w) // 需要提左儿子上来 
                {
                    rrot(u); // 右旋 
                    return del(u,val); // 递归,注意递归的节点必须是 u 
                }
                else // 需要提右儿子上来 
                {
                    lrot(u);
                    return del(u,val); // 递归,注意递归的节点必须是 u 
                }
            }
        }
    }
    else if(val<tree[u].val) 
	{
		bool f=del(tree[u].l,val); // 往左子树递归 
		if(f) tree[u].sum--; // 如果删除成功,那么子树大小 -- 
		return f;
	}
	else
	{
		bool f=del(tree[u].r,val); // 往右子树递归 
		if(f) tree[u].sum--; // 如果删除成功,那么子树大小 -- 
		return f;
	}
}

解决了插入删除两个比较困难的操作后,剩下的操作就非常好实现了。

求排名:

int getrk(int u,int val)
{
    if(u==0) return 1;
    if(tree[u].val==val) return tree[tree[u].l].sum+1;
    else if(val<tree[u].val) return getrk(tree[u].l,val);
    else return tree[tree[u].l].sum+tree[u].cnt+getrk(tree[u].r,val);
}

求排名为 val 的数:

int getbyrk(int u,int val)
{
    if(u==0) return 0;
    if(val<=tree[tree[u].l].sum) return getbyrk(tree[u].l,val);
    else if(val<=tree[tree[u].l].sum+tree[u].cnt) return tree[u].val;
    else return getbyrk(tree[u].r,val-tree[tree[u].l].sum-tree[u].cnt);
}

求前驱后继:

int getfrt(int u,int val)
{
    if(u==0) return -inf;
    if(val<=tree[u].val) return getfrt(tree[u].l,val);
    else return max(tree[u].val,getfrt(tree[u].r,val));
}

int getnxt(int u,int val)
{
    if(u==0) return inf;
    if(val>=tree[u].val) return getnxt(tree[u].r,val);
    else return min(tree[u].val,getnxt(tree[u].l,val));
}

完整代码:

#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <ctime>

using namespace std;

#define inf (((long long)1<<31)-1)

struct node
{
    int val,w; // 数值,权值 w 
    int cnt,sum; // 数值个数,子树大小 
    int l,r; // 左右儿子 
}tree[100005];

int q,cnt,rt;

inline void upd(int u)
{
    tree[u].sum=tree[tree[u].l].sum+tree[tree[u].r].sum+tree[u].cnt;
}

inline void lrot(int &u) // 左旋 
{
	int t=tree[u].r; // 先存下 Y 的编号
    tree[u].r=tree[t].l; // 右儿子变成 B 
    tree[t].l=u; // Y 的左儿子变成 X 
    tree[t].sum=tree[u].sum; // Y 变成了之前 X 的子树的根 
    upd(u); // 更新 X 的子树大小 
    u=t; // 完成旋转 
}

inline void rrot(int &u) // 右旋,和左旋原理一样 
{
	int t=tree[u].l;
    tree[u].l=tree[t].r;
    tree[t].r=u;
    tree[t].sum=tree[u].sum;
    upd(u);
    u=t;
}

void ins(int &u,int val) // 插入 
{
    if(u==0) // 如果递归到了空节点,那么在这里插入 
    {
        u=++cnt;
        tree[u].val=val;
        tree[u].cnt=1;
        tree[u].sum=1;
        tree[u].w=rand(); // 随机化权值 w 
        return;
    }
    tree[u].sum++; // 子树大小 ++ 
    if(val==tree[u].val) tree[u].cnt++; // 如果找到了值为 val 的节点,那么该节点的数值个数 ++ 
    else if(val<tree[u].val)
    {
        ins(tree[u].l,val); // 往左子树插入 
        if(tree[tree[u].l].w<tree[u].w) rrot(u); // 维护二叉堆的性质 
    }
    else
    {
        ins(tree[u].r,val); // 往右子树插入 
        if(tree[tree[u].r].w<tree[u].w) lrot(u); // 维护二叉堆的性质 
    }
}

bool del(int &u,int val) // 删除 
{
    if(u==0) return false; // 递归到了空节点,删除失败 
    if(val==tree[u].val) // 找到了 
    {
        if(tree[u].cnt>1) // 删除之后节点还存在 
        {
        	tree[u].sum--;
            tree[u].cnt--;
            return true; // 删除成功 
        }
        else
        {
        	if(tree[u].l==0||tree[u].r==0) // 只有一个儿子或者没有儿子,那么直接用儿子替换当前节点 
			{
				u=tree[u].l+tree[u].r;
				return true; // 删除成功 
			}
            else // 有两个儿子,那么我们可以把当前节点通过左旋右旋往下移动,直到可以直接删除 
            {
                if(tree[tree[u].l].w<tree[tree[u].r].w) // 需要提左儿子上来 
                {
                    rrot(u); // 右旋 
                    return del(u,val); // 递归,注意递归的节点必须是 u 
                }
                else // 需要提右儿子上来 
                {
                    lrot(u);
                    return del(u,val); // 递归,注意递归的节点必须是 u 
                }
            }
        }
    }
    else if(val<tree[u].val) 
	{
		bool f=del(tree[u].l,val); // 往左子树递归 
		if(f) tree[u].sum--; // 如果删除成功,那么子树大小 -- 
		return f;
	}
	else
	{
		bool f=del(tree[u].r,val); // 往右子树递归 
		if(f) tree[u].sum--; // 如果删除成功,那么子树大小 -- 
		return f;
	}
}

int getrk(int u,int val)
{
    if(u==0) return 1;
    if(tree[u].val==val) return tree[tree[u].l].sum+1;
    else if(val<tree[u].val) return getrk(tree[u].l,val);
    else return tree[tree[u].l].sum+tree[u].cnt+getrk(tree[u].r,val);
}

int getbyrk(int u,int val)
{
    if(u==0) return 0;
    if(val<=tree[tree[u].l].sum) return getbyrk(tree[u].l,val);
    else if(val<=tree[tree[u].l].sum+tree[u].cnt) return tree[u].val;
    else return getbyrk(tree[u].r,val-tree[tree[u].l].sum-tree[u].cnt);
}

int getfrt(int u,int val)
{
    if(u==0) return -inf;
    if(val<=tree[u].val) return getfrt(tree[u].l,val);
    else return max(tree[u].val,getfrt(tree[u].r,val));
}

int getnxt(int u,int val)
{
    if(u==0) return inf;
    if(val>=tree[u].val) return getnxt(tree[u].r,val);
    else return min(tree[u].val,getnxt(tree[u].l,val));
}

int main()
{
  	srand(time(NULL));
    scanf("%d",&q);
    while(q--)
    {
        int opt,x;
        scanf("%d%d",&opt,&x);
        if(opt==1) ins(rt,x);
        if(opt==2) del(rt,x);
        if(opt==3) printf("%d\n",getrk(rt,x));
        if(opt==4) printf("%d\n",getbyrk(rt,x));
        if(opt==5) printf("%d\n",getfrt(rt,x));
        if(opt==6) printf("%d\n",getnxt(rt,x));
    }
    return 0;
}