树链剖分学习笔记

树链剖分,就是用来增加代码长度的……

反正我模板题写了 4 kb/ll

树链剖分,简称树剖,是用来把一棵树划分成很多条互不相交的链,再用类似线段树的数据结构来维护信息的。通常,“树剖”指的是树链剖分中的重链剖分。

模板题

首先看一些定义:

  • 重儿子:对于每一个非叶节点,它所有儿子中子树最大的节点

  • 轻儿子:不是重儿子的其他儿子

  • 重边:一个非叶节点连向它的重儿子的边

  • 轻边:一个非叶节点连向它的轻儿子的边

  • 重链:以一个轻儿子开始,由若干条重边连接而成的链(单独一个轻儿子也能作为一条重链

很容易发现,每条重链都不会相交,这样我们就可以使用类似线段树这样的数据结构来维护这些重链了

然后我们很显然可以在 O(n)O(n) 的时间复杂度内求出节点 uu 的父节点 fatufat_u、子树大小 sizusiz_u 和重儿子 hsonuhson_u

代码如下:

void dfs1(int u,int fa)
{
	fat[u]=fa;
	siz[u]=1;
	dep[u]=dep[fa]+1;
	for(int i=h[u];i;i=nxt[i])
	{
		int v=to[i];
		if(v==fa)
		{
			continue;
		}
		dfs1(v,u);
		siz[u]+=siz[v];
		if(siz[v]>siz[hson[u]])
		{
			hson[u]=v;
		}
	}
}

接下来,我们考虑将整棵树映射到一个序列上,再使用线段树去维护它。很显然,可以使用 dfs 序,然后只要先递归重儿子,就可以保证同一条重链的区间是连续的了

具体的做法是,再跑一遍 dfs,在 O(n)O(n) 的时间复杂度内求出新的序列 aa、节点 uu 在新序列中的编号 iduid_u、节点 uu 所属链的开头的节点 toputop_u 和节点 uu 子树在新序列中的右端点 RuR_u

代码如下:

void dfs2(int u,int tpf,int fa) // tpf 是 u 所属的重链的开头节点 
{
	id[u]=++cnt;
	a[cnt]=val[u];
	top[u]=tpf;
	if(hson[u]==0) // 如果没有重儿子说明是叶子 
	{
		R[u]=cnt;
		return;
	}
	dfs2(hson[u],tpf,u); // 先递归重儿子,因为它和 u 属于同一条重链 
	for(int i=h[u];i;i=nxt[i])
	{
		int v=to[i];
		if(v==fa||v==hson[u])
		{
			continue;
		}
		dfs2(v,v,u); // 对于轻儿子,直接新开一条重链 
	}
	R[u]=cnt;
}

现在我们已经成功地把树“剖开”了,接下来的工作就是使用线段树来维护新的序列。

首先看两个对于 xx 的子树的操作。很显然,xx 的子树对应的区间是 [idx,Rx][id_x,R_x],那么子树加和子树和查询就相当于在这个区间上操作。

代码如下:

inline void addsubtree(int u,long long k)
{
	upd(1,1,n,id[u],R[u],k);
}

inline long long quesubtree(int u)
{
	return que(1,1,n,id[u],R[u]);
}

而对于两个节点 x,yx,y 的最短路径的操作,可以让它们不断往上跳重链,直到两点在同一重链里

具体的做法是,不断令所在重链开始节点的深度较大的那个节点往上跳一整条链,并计算这条链的贡献

画出来是这样的:

代码如下:

inline void addpath(int x,int y,long long k)
{
	while(top[x]!=top[y])
	{
		if(dep[top[x]]>dep[top[y]])
		{
			upd(1,1,n,id[top[x]],id[x],k); // 贡献 
			x=fat[top[x]]; // 往上跳 
		}
		else
		{
			upd(1,1,n,id[top[y]],id[y],k); // 贡献 
			y=fat[top[y]]; // 往上跳 
		}
	}
	upd(1,1,n,min(id[x],id[y]),max(id[x],id[y]),k); // 最后的一小段 
}

inline long long quepath(int x,int y) // 同理 
{
	long long res=0;
	while(top[x]!=top[y])
	{
		if(dep[top[x]]>dep[top[y]])
		{
			res=(res+que(1,1,n,id[top[x]],id[x]))%p;
			x=fat[top[x]];
		}
		else
		{
			res=(res+que(1,1,n,id[top[y]],id[y]))%p;
			y=fat[top[y]];
		}
	}
	res=(res+que(1,1,n,min(id[x],id[y]),max(id[x],id[y])))%p;
	return res;
}

复杂度证明:

首先子树操作的复杂度显然是 O(logn)O(\log n)

然后对于最短路径操作,由于 uu 的轻儿子的子树大小最多是 sizu2\lfloor\dfrac{siz_u}{2}\rfloor,所以最多只会跳 logn\log n 次重边,那么时间复杂度就是 O(log2n)O(\log^2n)

综上,树剖的时间复杂度为 O(log2n)O(\log^2n)

模板题代码如下:

#include <iostream>
#include <cstdio>

using namespace std;

const long long S=200005,MS=100005;

int n,m,r;
long long p,val[MS];
int esum,to[S],nxt[S],h[MS];
int fat[MS],siz[MS],hson[MS],dep[MS];
int cnt,id[MS],top[MS],R[MS];
long long a[MS];
long long sum[MS<<2],lazy[MS<<2];

inline void add(int x,int y)
{
	to[++esum]=y;
	nxt[esum]=h[x];
	h[x]=esum;
}

void dfs1(int u,int fa)
{
	fat[u]=fa;
	siz[u]=1;
	dep[u]=dep[fa]+1;
	for(int i=h[u];i;i=nxt[i])
	{
		int v=to[i];
		if(v==fa)
		{
			continue;
		}
		dfs1(v,u);
		siz[u]+=siz[v];
		if(siz[v]>siz[hson[u]])
		{
			hson[u]=v;
		}
	}
}

void dfs2(int u,int tpf,int fa) // tpf 是 u 所属的重链的开头节点 
{
	id[u]=++cnt;
	a[cnt]=val[u];
	top[u]=tpf;
	if(hson[u]==0) // 如果没有重儿子说明是叶子 
	{
		R[u]=cnt;
		return;
	}
	dfs2(hson[u],tpf,u); // 先递归重儿子,因为它和 u 属于同一条重链 
	for(int i=h[u];i;i=nxt[i])
	{
		int v=to[i];
		if(v==fa||v==hson[u])
		{
			continue;
		}
		dfs2(v,v,u); // 对于轻儿子,直接新开一条重链 
	}
	R[u]=cnt;
}

inline void updata(int u)
{
	sum[u]=(sum[u<<1]+sum[u<<1|1])%p;
}

inline void lazydown(int u,int l,int r)
{
	if(lazy[u]==0)
	{
		return;
	}
	int mid=l+r>>1;
	sum[u<<1]=(sum[u<<1]+lazy[u]*(mid-l+1))%p;
	sum[u<<1|1]=(sum[u<<1|1]+lazy[u]*(r-mid))%p;
	lazy[u<<1]=(lazy[u<<1]+lazy[u])%p;
	lazy[u<<1|1]=(lazy[u<<1|1]+lazy[u])%p;
	lazy[u]=0;
}

void build(int u,int l,int r)
{
	if(l==r)
	{
		sum[u]=a[l];
		return;
	}
	int mid=l+r>>1;
	build(u<<1,l,mid);
	build(u<<1|1,mid+1,r);
	updata(u);
}

void upd(int u,int l,int r,int L,int R,long long k)
{
	if(r<L||l>R)
	{
		return;
	}
	if(l>=L&&r<=R)
	{
		sum[u]=(sum[u]+k*(r-l+1))%p;
		lazy[u]=(lazy[u]+k)%p;
		return;
	}
	lazydown(u,l,r);
	int mid=l+r>>1;
	if(L<=mid)
	{
		upd(u<<1,l,mid,L,R,k);
	}
	if(r>=mid+1)
	{
		upd(u<<1|1,mid+1,r,L,R,k);
	}
	updata(u);
}

long long que(int u,int l,int r,int L,int R)
{
	if(r<L||l>R)
	{
		return 0;
	}
	if(l>=L&&r<=R)
	{
		return sum[u];
	}
	lazydown(u,l,r);
	int mid=l+r>>1;
	long long res=0;
	if(L<=mid)
	{
		res=(res+que(u<<1,l,mid,L,R))%p;
	}
	if(R>=mid+1)
	{
		res=(res+que(u<<1|1,mid+1,r,L,R))%p;
	}
	return res;
}

inline void addpath(int x,int y,long long k)
{
	while(top[x]!=top[y])
	{
		if(dep[top[x]]>dep[top[y]])
		{
			upd(1,1,n,id[top[x]],id[x],k); // 贡献 
			x=fat[top[x]]; // 往上跳 
		}
		else
		{
			upd(1,1,n,id[top[y]],id[y],k); // 贡献 
			y=fat[top[y]]; // 往上跳 
		}
	}
	upd(1,1,n,min(id[x],id[y]),max(id[x],id[y]),k); // 最后的一小段 
}

inline long long quepath(int x,int y) // 同理 
{
	long long res=0;
	while(top[x]!=top[y])
	{
		if(dep[top[x]]>dep[top[y]])
		{
			res=(res+que(1,1,n,id[top[x]],id[x]))%p;
			x=fat[top[x]];
		}
		else
		{
			res=(res+que(1,1,n,id[top[y]],id[y]))%p;
			y=fat[top[y]];
		}
	}
	res=(res+que(1,1,n,min(id[x],id[y]),max(id[x],id[y])))%p;
	return res;
}

inline void addsubtree(int u,long long k)
{
	upd(1,1,n,id[u],R[u],k);
}

inline long long quesubtree(int u)
{
	return que(1,1,n,id[u],R[u]);
}

int main()
{
	scanf("%d%d%d%lld",&n,&m,&r,&p);
	for(int i=1;i<=n;i++)
	{
		scanf("%lld",&val[i]);
		val[i]%=p;
	}
	for(int i=1;i<=n-1;i++)
	{
		int x,y;
		scanf("%d%d",&x,&y);
		add(x,y);
		add(y,x);
	}
	dfs1(r,0);
	dfs2(r,r,0);
	build(1,1,n);
	while(m--)
	{
		int ty;
		scanf("%d",&ty);
		if(ty==1)
		{
			int x,y;
			long long z;
			scanf("%d%d%lld",&x,&y,&z);
			addpath(x,y,z);
		}
		else if(ty==2)
		{
			int x,y;
			scanf("%d%d",&x,&y);
			printf("%lld\n",quepath(x,y));
		}
		else if(ty==3)
		{
			int x;
			long long z;
			scanf("%d%lld",&x,&z);
			addsubtree(x,z);
		}
		else
		{
			int x;
			scanf("%d",&x);
			printf("%lld\n",quesubtree(x));
		}
	}
	return 0;
}

练习题目

P2146 [NOI2015] 软件包管理器

P2486 [SDOI2011]染色

P2590 [ZJOI2008]树的统计

P3178 [HAOI2015]树上操作