【PR #15】二叉搜索树 做题记录

有一棵 nn 个点的树,每个点上有一个二叉搜索树。

然后依次进行 mm 次操作:

  • 1 u v x:在 uvu\to v 路径上每个点的二叉搜索树中插入元素 xx,即 p{uv}\forall p\in \{u\to v\} 执行 insert(rtp,x)\text{insert}(rt_p,x)

    void insert(int&p,int x){
        if(!p) return p=x,void();
        if(x<p) insert(ch[p][0],x);
        else insert(ch[p][1],x);
    }
    
  • 0 u w:在点 uu 上的二叉搜索树中执行 ask(rtu,w)\text{ask}(rt_u,w),求其返回值:

    long long ask(int p,int x){
        if(x==p) return x;
        if(x<p) return ch[p][0]?ask(ch[p][0])+p:p;
        else return ch[p][1]?ask(ch[p][1])+p:p;
    }
    

1n,m,x,w2×1051\le n,m,x,w\le 2\times 10^5,每次 11 操作的 xx 互不相同。

先考虑单点怎么做,直接维护二叉搜索树中每个点的父亲即可,一个点 xx 的父亲是其插入时前驱和后继中插入时间较晚的那个。

接下来考虑链,一个朴素的想法是差分,问题变为维护加入/删除某次操作后二叉搜索树的形态。

使用超强注意力,用二元组 (a,b)(a,b) 记录点 aa 的插入时间是 bb,那么注意到 a[1,x]a\in [1,x] 的点和 a[x,m]a\in [x,m] 的点是独立的。具体的,对于 a[1,x]a\in [1,x] 的点,可以这样判断它们中哪些是 (x,i)(x,i) 的祖先:

  • 将这些点按照 aa 从大到小排序;
  • t=it=i,依次遍历排好序后的每个点 (a,b)(a,b)
    • b<tb<t 则其为 (x,i)(x,i) 祖先,答案加上 aa,令 t=bt=b
    • 否则其非 (x,i)(x,i) 祖先;

对于 a[x,m]a\in [x,m] 的点是一样的,将从大到小排序改为从小到大排序即可。

那么问题转化为区间中 aa 的前缀/后缀最小值的 bb 的和,直接楼房重建线段树即可 O(nlog2n)O(n\log ^2n) 维护。

现在考虑树,不难发现和链是基本一样的,线段树合并即可。

时间复杂度 O(nlog2n)O(n\log ^2 n),代码如下:

#include <iostream>
#include <algorithm>
#include <cstdio>
#include <vector>

using namespace std;

typedef long long ll;

/*
insert/delete pot
query [l,r] premin sum or sufmin sum
*/
namespace seg
{
	const int TS=30000005,inf=1e8;
	int cnt,ls[TS],rs[TS];
	int mn[TS];
	ll sml[TS],smr[TS];
	inline void init(){mn[0]=inf;}
	ll quesml(int u,int l,int r,int k) // premin(<k) sum
	{
		if(u==0||mn[u]>=k) return 0;
		if(l==r) return mn[u]<k?sml[u]:0;
		int mid=l+r>>1;
		if(mn[ls[u]]<k) return quesml(ls[u],l,mid,k)+sml[u]-sml[ls[u]];
		else return quesml(rs[u],mid+1,r,k);
	}
	ll quesmr(int u,int l,int r,int k) // sufmin(<k) sum
	{
		if(u==0||mn[u]>=k) return 0;
		if(l==r) return mn[u]<k?smr[u]:0;
		int mid=l+r>>1;
		if(mn[rs[u]]<k) return smr[u]-smr[rs[u]]+quesmr(rs[u],mid+1,r,k);
		else return quesmr(ls[u],l,mid,k);
	}
	inline void upda(int u,int l,int r)
	{
		int mid=l+r>>1;
		int sl=ls[u],sr=rs[u];
		mn[u]=min(mn[sl],mn[sr]);
		sml[u]=sml[sl]+quesml(sr,mid+1,r,mn[sl]);
		smr[u]=quesmr(sl,l,mid,mn[sr])+smr[sr];
	}
	void updp(int &u,int l,int r,int p,int x,int val)
	{
		if(u==0) u=++cnt;
		if(l==r) return mn[u]=x,sml[u]=smr[u]=val,void();
		int mid=l+r>>1;
		if(p<=mid) updp(ls[u],l,mid,p,x,val);
		else updp(rs[u],mid+1,r,p,x,val);
		upda(u,l,r);
	}
	int merge(int x,int y,int l,int r)
	{
		if(x==0||y==0) return x+y;
		if(l==r)
		{
			mn[x]=min(mn[x],mn[y]);
			sml[x]=max(sml[x],sml[y]);
			smr[x]=sml[x];
			return x;
		}
		int mid=l+r>>1;
		ls[x]=merge(ls[x],ls[y],l,mid);
		rs[x]=merge(rs[x],rs[y],mid+1,r);
		upda(x,l,r);
		return x;
	}
	void quesmllr(int u,int l,int r,int L,int R,int &k,ll &res)
	{
		if(u==0||l>R||r<L) return;
		if(l>=L&&r<=R)
		{
			res+=quesml(u,l,r,k);
			k=min(k,mn[u]);
			return;
		}
		int mid=l+r>>1;
		if(L<=mid) quesmllr(ls[u],l,mid,L,R,k,res);
		if(R>=mid+1) quesmllr(rs[u],mid+1,r,L,R,k,res);
	}
	void quesmrlr(int u,int l,int r,int L,int R,int &k,ll &res)
	{
		if(u==0||l>R||r<L) return;
		if(l>=L&&r<=R)
		{
			res+=quesmr(u,l,r,k);
			k=min(k,mn[u]);
			return;
		}
		int mid=l+r>>1;
		if(R>=mid+1) quesmrlr(rs[u],mid+1,r,L,R,k,res);
		if(L<=mid) quesmrlr(ls[u],l,mid,L,R,k,res);
	}
}

const int S=200005,BS=25;

int n,m;
vector<int> g[S];
int dep[S],fat[S][BS];
vector<pair<int,int> > vec[S];
int tot,b[S],val[S],rt[S];
ll ans[S];

void dfs(int u,int fa)
{
	dep[u]=dep[fa]+1;
	fat[u][0]=fa;
	for(int i=1;i<=BS-3;i++) fat[u][i]=fat[fat[u][i-1]][i-1];
	for(int v:g[u]) if(v!=fa) dfs(v,u);
}

inline int quelca(int x,int y)
{
	if(dep[x]<dep[y]) swap(x,y);
	for(int i=BS-3;i>=0;i--) if(dep[fat[x][i]]>=dep[y]) x=fat[x][i];
	if(x==y) return x;
	for(int i=BS-3;i>=0;i--)
	{
		if(fat[x][i]!=fat[y][i])
		{
			x=fat[x][i];
			y=fat[y][i];
		}
	}
	return fat[x][0];
}

void dfs2(int u,int fa)
{
	for(int v:g[u])
	{
		if(v==fa) continue;
		dfs2(v,u);
		rt[u]=seg::merge(rt[u],rt[v],1,tot);
	}
	for(auto t:vec[u])
	{
		int x=t.first,y=t.second;
		if(y==-1)
		{
			int k=x;
			seg::quesmrlr(rt[u],1,tot,1,val[x],k,ans[x]);
			k=x;
			seg::quesmllr(rt[u],1,tot,val[x],tot,k,ans[x]);
			k=x;
			ll del=0;
			seg::quesmllr(rt[u],1,tot,val[x],val[x],k,del);
			ans[x]-=del;
		}
		else
		{
			if(x==1) seg::updp(rt[u],1,tot,val[y],y,b[val[y]]);
			else seg::updp(rt[u],1,tot,val[y],seg::inf,0);
		}
	}
}

int main()
{
	ios::sync_with_stdio(false),cin.tie(0),cout.tie(0);
	cin>>n>>m;
	for(int i=1;i<=n-1;i++)
	{
		int x,y;
		cin>>x>>y;
		g[x].push_back(y);
		g[y].push_back(x);
	}
	dfs(1,0);
	vector<int> queid;
	for(int i=1;i<=m;i++)
	{
		int op,x,y;
		cin>>op>>x>>y;
		if(op==1)
		{
			int w;
			cin>>w;
			int l=quelca(x,y);
			vec[x].emplace_back(1,i);
			vec[y].emplace_back(1,i);
			vec[fat[l][0]].emplace_back(-1,i);
			b[++tot]=val[i]=w;
		}
		else
		{
			vec[x].emplace_back(i,-1);
			b[++tot]=val[i]=y;
			queid.push_back(i);
		}
	}
	sort(b+1,b+tot+1);
	tot=unique(b+1,b+tot+1)-b-1;
	for(int i=1;i<=m;i++) val[i]=lower_bound(b+1,b+tot+1,val[i])-b;
	seg::init();
	dfs2(1,0);
	for(int x:queid) cout<<ans[x]<<'\n';
	return 0;
}