长链剖分学习笔记

Part 1 定义

在长链剖分中,我们定义 uu 的长儿子为:

  • uu 的所有儿子 vv 中,子树最高的那个 vv

与重链剖分一样,每个点和它的长儿子划分到同一条长链。

Part 2 性质

2.1 每个点属于且仅属于一条长链

2.2 一个点到根的路径上长链的条数是 O(n)O(\sqrt n)

对于一个子树高度为 kk 的节点 uu,若它的父亲 fafauu 并不在一条长链中,则 fafa 的子树大小至少为 2k+12k+1

那么如果跳了 xx 条长链,则子树大小至少为 i=1xi\sum\limits_{i=1}^x i

所以一个节点到根的路径上,长链的条数是 O(n)O(\sqrt n) 的。

2.3 一个点的 kk 级祖先所在的长链的长度至少为 kk

Part 3 应用

3.1 O(nlogn)/O(1)O(n\log n)/O(1) 在线求 kk 级祖先

根据性质 2.1 和 2.3,不难想到 uukk 级祖先距离 uu2log2k2^{\lfloor\log_2k\rfloor} 级祖先所在的长链 lnklnk 的链头一定不超过 lnk|lnk| 个节点。

那么维护出每个点 uu2j2^j 级祖先 fau,jfa_{u,j},并维护每条长 lenlen 的长链链头的 0len0\sim len 级祖先和链上节点的顺序。

查询的时候只需要找到 uu2log2k2^{\lfloor\log_2k\rfloor} 级祖先 fau,log2kfa_{u,\lfloor\log_2k\rfloor},然后从链头开始跳即可。

3.2 优化树形 dp

某些树形 dp 的第二维只和深度有关系,且转移时的整体变化可以快速维护,则可以使用长链剖分优化。

例如 CF1009F Dominant Indices

给定一棵以 11 为根,nn 个节点的树。设 d(u,x)d(u,x)uu 子树中到 uu 距离为 xx 的节点数。

对于每个点,求一个最小的 kk,使得 d(u,k)d(u,k) 最大。

不难想到一个朴素的 dp:设 fu,if_{u,i} 表示点 uu 子树内距离 uuii 的节点个数。

转移:

fu,0=1fu,j=vsonufv,j1f_{u,0}=1\\ f_{u,j}=\sum\limits_{v\in son_u} f_{v,j-1}

考虑优化,设点 uu 距离它所在长链的链头 tputp_udepudep_u,则考虑用 tputp_u 向下 depu+jdep_u+j 个位置的点来保存 uu 的信息。根据性质 2.2,这样的点一定存在。

那么把 fu,jf_{u,j} 挂到 ftpu,depu+jf_{tp_u,dep_{u}+j} 上,则每个点的转移相当于直接继承它长儿子的信息,并把其它儿子的长链上信息暴力合并。

由于每条长链只会被合并一次且长链的总长为 nn,所以这样做的时间复杂度为 O(n)O(n)

代码

#include <iostream>
#include <cstdio>
#include <vector>

using namespace std;

const int S=1000005;

int n;
vector<int> g[S];
int len[S],mx[S];
int top[S],dep[S];
int mxp[S];
vector<int> f[S];
int ans[S];

void dfs(int u,int fa)
{
	for(int v:g[u])
	{
		if(v==fa) continue;
		dfs(v,u);
		len[u]=max(len[u],len[v]);
		if(len[v]>len[mx[u]]) mx[u]=v;
	}
	len[u]++;
}

void dfs2(int u,int fa,int tp)
{
	top[u]=tp;
	dep[u]=u==tp?1:dep[fa]+1;
	if(mx[u]!=0) dfs2(mx[u],u,tp);
	for(int v:g[u])
	{
		if(v==fa||v==mx[u]) continue;
		dfs2(v,u,v);
	}
}

void dfs3(int u,int fa)
{
	if(mx[u]!=0) dfs3(mx[u],u);
	int id=top[u];
	for(int v:g[u])
	{
		if(v==fa||v==mx[u]) continue;
		dfs3(v,u);
		for(int j=1;j<f[v].size();j++)
		{
			int idx=dep[u]+j;
			f[id][idx]+=f[v][j];
			if(f[id][idx]>f[id][mxp[id]]) mxp[id]=idx;
			else if(f[id][idx]==f[id][mxp[id]]&&idx<mxp[id]) mxp[id]=idx;
		}
	}
	int idx=dep[u];
	f[id][idx]++;
	if(f[id][idx]>f[id][mxp[id]]) mxp[id]=idx;
	else if(f[id][idx]==f[id][mxp[id]]&&idx<mxp[id]) mxp[id]=idx;
	ans[u]=mxp[id]-dep[u];
}

int main()
{
	scanf("%d",&n);
	for(int i=1;i<=n-1;i++)
	{
		int x,y;
		scanf("%d%d",&x,&y);
		g[x].push_back(y),g[y].push_back(x);
	}
	dfs(1,0),dfs2(1,0,1);
	for(int i=1;i<=n;i++) f[i]={0};
	for(int i=1;i<=n;i++) f[top[i]].push_back(0);
	dfs3(1,0);
	for(int i=1;i<=n;i++) printf("%d\n",ans[i]);
	return 0;
}

更多例题