树链剖分,就是用来增加代码长度的……
反正我模板题写了 4 kb/ll
树链剖分,简称树剖,是用来把一棵树划分成很多条互不相交的链,再用类似线段树的数据结构来维护信息的。通常,“树剖”指的是树链剖分中的重链剖分。
首先看一些定义:
-
重儿子:对于每一个非叶节点,它所有儿子中子树最大的节点
-
轻儿子:不是重儿子的其他儿子
-
重边:一个非叶节点连向它的重儿子的边
-
轻边:一个非叶节点连向它的轻儿子的边
-
重链:以一个轻儿子开始,由若干条重边连接而成的链(单独一个轻儿子也能作为一条重链)

很容易发现,每条重链都不会相交,这样我们就可以使用类似线段树这样的数据结构来维护这些重链了。
然后我们很显然可以在 的时间复杂度内求出节点 的父节点 、子树大小 和重儿子 。
代码如下:
void dfs1(int u,int fa)
{
fat[u]=fa;
siz[u]=1;
dep[u]=dep[fa]+1;
for(int i=h[u];i;i=nxt[i])
{
int v=to[i];
if(v==fa)
{
continue;
}
dfs1(v,u);
siz[u]+=siz[v];
if(siz[v]>siz[hson[u]])
{
hson[u]=v;
}
}
}
接下来,我们考虑将整棵树映射到一个序列上,再使用线段树去维护它。很显然,可以使用 dfs 序,然后只要先递归重儿子,就可以保证同一条重链的区间是连续的了。
具体的做法是,再跑一遍 dfs,在 的时间复杂度内求出新的序列 、节点 在新序列中的编号 、节点 所属链的开头的节点 和节点 子树在新序列中的右端点 。
代码如下:
void dfs2(int u,int tpf,int fa) // tpf 是 u 所属的重链的开头节点
{
id[u]=++cnt;
a[cnt]=val[u];
top[u]=tpf;
if(hson[u]==0) // 如果没有重儿子说明是叶子
{
R[u]=cnt;
return;
}
dfs2(hson[u],tpf,u); // 先递归重儿子,因为它和 u 属于同一条重链
for(int i=h[u];i;i=nxt[i])
{
int v=to[i];
if(v==fa||v==hson[u])
{
continue;
}
dfs2(v,v,u); // 对于轻儿子,直接新开一条重链
}
R[u]=cnt;
}
现在我们已经成功地把树“剖开”了,接下来的工作就是使用线段树来维护新的序列。
首先看两个对于 的子树的操作。很显然, 的子树对应的区间是 ,那么子树加和子树和查询就相当于在这个区间上操作。
代码如下:
inline void addsubtree(int u,long long k)
{
upd(1,1,n,id[u],R[u],k);
}
inline long long quesubtree(int u)
{
return que(1,1,n,id[u],R[u]);
}
而对于两个节点 的最短路径的操作,可以让它们不断往上跳重链,直到两点在同一重链里。
具体的做法是,不断令所在重链开始节点的深度较大的那个节点往上跳一整条链,并计算这条链的贡献。
画出来是这样的:

代码如下:
inline void addpath(int x,int y,long long k)
{
while(top[x]!=top[y])
{
if(dep[top[x]]>dep[top[y]])
{
upd(1,1,n,id[top[x]],id[x],k); // 贡献
x=fat[top[x]]; // 往上跳
}
else
{
upd(1,1,n,id[top[y]],id[y],k); // 贡献
y=fat[top[y]]; // 往上跳
}
}
upd(1,1,n,min(id[x],id[y]),max(id[x],id[y]),k); // 最后的一小段
}
inline long long quepath(int x,int y) // 同理
{
long long res=0;
while(top[x]!=top[y])
{
if(dep[top[x]]>dep[top[y]])
{
res=(res+que(1,1,n,id[top[x]],id[x]))%p;
x=fat[top[x]];
}
else
{
res=(res+que(1,1,n,id[top[y]],id[y]))%p;
y=fat[top[y]];
}
}
res=(res+que(1,1,n,min(id[x],id[y]),max(id[x],id[y])))%p;
return res;
}
复杂度证明:
首先子树操作的复杂度显然是 。
然后对于最短路径操作,由于 的轻儿子的子树大小最多是 ,所以最多只会跳 次重边,那么时间复杂度就是 。
综上,树剖的时间复杂度为 。
模板题代码如下:
#include <iostream>
#include <cstdio>
using namespace std;
const long long S=200005,MS=100005;
int n,m,r;
long long p,val[MS];
int esum,to[S],nxt[S],h[MS];
int fat[MS],siz[MS],hson[MS],dep[MS];
int cnt,id[MS],top[MS],R[MS];
long long a[MS];
long long sum[MS<<2],lazy[MS<<2];
inline void add(int x,int y)
{
to[++esum]=y;
nxt[esum]=h[x];
h[x]=esum;
}
void dfs1(int u,int fa)
{
fat[u]=fa;
siz[u]=1;
dep[u]=dep[fa]+1;
for(int i=h[u];i;i=nxt[i])
{
int v=to[i];
if(v==fa)
{
continue;
}
dfs1(v,u);
siz[u]+=siz[v];
if(siz[v]>siz[hson[u]])
{
hson[u]=v;
}
}
}
void dfs2(int u,int tpf,int fa) // tpf 是 u 所属的重链的开头节点
{
id[u]=++cnt;
a[cnt]=val[u];
top[u]=tpf;
if(hson[u]==0) // 如果没有重儿子说明是叶子
{
R[u]=cnt;
return;
}
dfs2(hson[u],tpf,u); // 先递归重儿子,因为它和 u 属于同一条重链
for(int i=h[u];i;i=nxt[i])
{
int v=to[i];
if(v==fa||v==hson[u])
{
continue;
}
dfs2(v,v,u); // 对于轻儿子,直接新开一条重链
}
R[u]=cnt;
}
inline void updata(int u)
{
sum[u]=(sum[u<<1]+sum[u<<1|1])%p;
}
inline void lazydown(int u,int l,int r)
{
if(lazy[u]==0)
{
return;
}
int mid=l+r>>1;
sum[u<<1]=(sum[u<<1]+lazy[u]*(mid-l+1))%p;
sum[u<<1|1]=(sum[u<<1|1]+lazy[u]*(r-mid))%p;
lazy[u<<1]=(lazy[u<<1]+lazy[u])%p;
lazy[u<<1|1]=(lazy[u<<1|1]+lazy[u])%p;
lazy[u]=0;
}
void build(int u,int l,int r)
{
if(l==r)
{
sum[u]=a[l];
return;
}
int mid=l+r>>1;
build(u<<1,l,mid);
build(u<<1|1,mid+1,r);
updata(u);
}
void upd(int u,int l,int r,int L,int R,long long k)
{
if(r<L||l>R)
{
return;
}
if(l>=L&&r<=R)
{
sum[u]=(sum[u]+k*(r-l+1))%p;
lazy[u]=(lazy[u]+k)%p;
return;
}
lazydown(u,l,r);
int mid=l+r>>1;
if(L<=mid)
{
upd(u<<1,l,mid,L,R,k);
}
if(r>=mid+1)
{
upd(u<<1|1,mid+1,r,L,R,k);
}
updata(u);
}
long long que(int u,int l,int r,int L,int R)
{
if(r<L||l>R)
{
return 0;
}
if(l>=L&&r<=R)
{
return sum[u];
}
lazydown(u,l,r);
int mid=l+r>>1;
long long res=0;
if(L<=mid)
{
res=(res+que(u<<1,l,mid,L,R))%p;
}
if(R>=mid+1)
{
res=(res+que(u<<1|1,mid+1,r,L,R))%p;
}
return res;
}
inline void addpath(int x,int y,long long k)
{
while(top[x]!=top[y])
{
if(dep[top[x]]>dep[top[y]])
{
upd(1,1,n,id[top[x]],id[x],k); // 贡献
x=fat[top[x]]; // 往上跳
}
else
{
upd(1,1,n,id[top[y]],id[y],k); // 贡献
y=fat[top[y]]; // 往上跳
}
}
upd(1,1,n,min(id[x],id[y]),max(id[x],id[y]),k); // 最后的一小段
}
inline long long quepath(int x,int y) // 同理
{
long long res=0;
while(top[x]!=top[y])
{
if(dep[top[x]]>dep[top[y]])
{
res=(res+que(1,1,n,id[top[x]],id[x]))%p;
x=fat[top[x]];
}
else
{
res=(res+que(1,1,n,id[top[y]],id[y]))%p;
y=fat[top[y]];
}
}
res=(res+que(1,1,n,min(id[x],id[y]),max(id[x],id[y])))%p;
return res;
}
inline void addsubtree(int u,long long k)
{
upd(1,1,n,id[u],R[u],k);
}
inline long long quesubtree(int u)
{
return que(1,1,n,id[u],R[u]);
}
int main()
{
scanf("%d%d%d%lld",&n,&m,&r,&p);
for(int i=1;i<=n;i++)
{
scanf("%lld",&val[i]);
val[i]%=p;
}
for(int i=1;i<=n-1;i++)
{
int x,y;
scanf("%d%d",&x,&y);
add(x,y);
add(y,x);
}
dfs1(r,0);
dfs2(r,r,0);
build(1,1,n);
while(m--)
{
int ty;
scanf("%d",&ty);
if(ty==1)
{
int x,y;
long long z;
scanf("%d%d%lld",&x,&y,&z);
addpath(x,y,z);
}
else if(ty==2)
{
int x,y;
scanf("%d%d",&x,&y);
printf("%lld\n",quepath(x,y));
}
else if(ty==3)
{
int x;
long long z;
scanf("%d%lld",&x,&z);
addsubtree(x,z);
}
else
{
int x;
scanf("%d",&x);
printf("%lld\n",quesubtree(x));
}
}
return 0;
}