Boruvka 算法 学习笔记

算法简介

Boruvka 算法是一种用来求解最小生成树的算法,它的流程如下:

  1. 刚开始每个点自己为一个连通块,接下来执行步骤 2 至 3 直到所有点在同一个连通块中;
  2. 对于每个连通块 ii,找到距离它最近,不在连通块中且与连通块有直接连边的点 toito_i,记离 toito_i 最近的连通块中的点为 ouiou_i
  3. 对于每个连通块 ii,若 ouiou_itoito_i 不在同一连通块中,则在生成树中新增边 (oui,toi)(ou_i,to_i),合并 ouiou_itoito_i 所在的连通块;

这样每次合并连通块后连通块个数至少会减半,所以时间复杂度是 O((T+n)logn)O((T+n)\log n) 的,其中 TT 是步骤 22 的时间复杂度。

这个算法在最小生成树板子上的表现平平无奇,但是却十分擅长处理边很多的图(稠密图)的最小生成树问题。

经典例题

AT_cf17_final_j Tree MST

给定一棵 nn 个节点的树,现有有一张完全图,两点 x,yx,y 之间的边长为 wx+wy+disx,yw_x+w_y+dis_{x,y},其中 disdis 表示树上两点的距离。

求完全图的最小生成树。

n2×105n \leq 2 \times 10^5

题解

直接上 Boruvka,找 ouiou_itoito_i 可以做一次 up and down DP。注意到 xx 到其它点的路径可以看作是向上走再向下走,那么跑两次 dfs。第一遍 dfs 从根把信息上推,求出每个节点距离其子树内最近与最近点不在一个连通块的最近的点;第二遍 dfs 把信息下放,下放到 xx 的时候就能求出距离 xx 最近的与 xx 不在一个连通块中的点。

时间复杂度 O(nlogn)O(n\log n)

代码如下:

// Problem: Tree MST
// Contest: Luogu
// URL: https://www.luogu.com.cn/problem/AT_cf17_final_j
// Memory Limit: 256 MB
// Time Limit: 5000 ms
// 
// Powered by CP Editor (https://cpeditor.org)

#include <iostream>
#include <algorithm>
#include <cstdio>
#include <vector>

using namespace std;

const int S=500005;

int n;
long long w[S];
int esum,to[S],nxt[S],h[S];
long long c[S],dep[S];
int col[S];
int ansid[S];
long long ans[S];

struct node
{
	int id;
	int id0,id1;
	long long mn0,mn1;
	inline void init()
	{
		id=id0=id1=0;
		mn0=mn1=1e17;
	}
	inline void operator+=(node &b)
	{
		vector<pair<long long,int>> vec;
		vec.push_back(make_pair(mn0,id0));
		vec.push_back(make_pair(mn1,id1));
		vec.push_back(make_pair(b.mn0,b.id0));
		vec.push_back(make_pair(b.mn1,b.id1));
		sort(vec.begin(),vec.end());
		id0=vec[0].second,mn0=vec[0].first;
		for(int i=1;i<4;i++)
		{
			if(col[vec[i].second]!=col[id0])
			{
				id1=vec[i].second,mn1=vec[i].first;
				break;
			}
		}
	}
}dp[S];

inline void add(int x,int y,long long w)
{
	to[++esum]=y;
	c[esum]=w;
	nxt[esum]=h[x];
	h[x]=esum;
}

void initdep(int u,int fa)
{
	for(int i=h[u];i;i=nxt[i])
	{
		int v=to[i];
		if(v==fa) continue;
		dep[v]=dep[u]+c[i];
		initdep(v,u);
	}
}

void updfs(int u,int fa)
{
	dp[u].init();
	dp[u].id=u;
	dp[u].id0=u,dp[u].mn0=dep[u]+w[u];
	for(int i=h[u];i;i=nxt[i])
	{
		int v=to[i];
		if(v==fa) continue;
		updfs(v,u);
		dp[u]+=dp[v];
	}
}

void dwndfs(int u,int fa)
{
	if(fa==0)
	{
		dp[u].mn0-=dep[u]*2;
		dp[u].mn1-=dep[u]*2;
	}
	if(col[u]!=col[dp[u].id0])
	{
		ansid[u]=dp[u].id0;
		ans[u]=dep[u]+w[u]+dp[u].mn0;
	}
	else
	{
		ansid[u]=dp[u].id1;
		ans[u]=dep[u]+w[u]+dp[u].mn1;
	}
	for(int i=h[u];i;i=nxt[i])
	{
		int v=to[i];
		if(v==fa) continue;
		dp[v].mn0-=dep[v]*2;
		dp[v].mn1-=dep[v]*2;
		dp[v]+=dp[u];
		dwndfs(v,u);
	}
}

int fnd(int u)
{
	return col[u]==u?u:col[u]=fnd(col[u]);
}

int main()
{
	scanf("%d",&n);
	for(int i=1;i<=n;i++) scanf("%lld",&w[i]);
	for(int i=1;i<=n-1;i++)
	{
		int x,y;
		long long w;
		scanf("%d%d%lld",&x,&y,&w);
		add(x,y,w),add(y,x,w);
	}
	initdep(1,0);
	for(int i=1;i<=n;i++) col[i]=i;
	long long rss=0;
	while(1)
	{
		for(int i=1;i<=n;i++) col[i]=fnd(i);
		bool f=true;
		for(int i=1;i<=n;i++) f&=col[i]==col[1];
		if(f) break;
		updfs(1,0);
		dwndfs(1,0);
		for(int i=1;i<=n;i++)
		{
			int rt=fnd(i);
			if(ans[i]<ans[rt]) ans[rt]=ans[i],ansid[rt]=ansid[i];
		}
		for(int i=1;i<=n;i++)
		{
			int rt=fnd(i);
			if(rt!=fnd(ansid[rt]))
			{
				col[rt]=fnd(ansid[rt]);
				rss+=ans[rt];
			}
		}
	}
	printf("%lld\n",rss);
	return 0;
}

CF888G Xor-MST
CF1550F Jumping Around
CF1648E Air Reform
CF1305G Kuroni and Antihype
题解

加入超级源点 00,令边 (x,y)(x,y) 的边权为 ax+aya_x+a_y,那么答案即为最大生成树的边权和减去 ai\sum a_i

那么直接 Brouvka,每次维护子集最大值、和最大值不在同一个连通块的最大值即可。

代码如下:

#include <iostream>
#include <algorithm>
#include <cstdio>

using namespace std;

const int S=200005,LM=1<<18;

int n,a[S];
int fa[S];
int mu[S],mv[S];

int fnd(int x){return fa[x]==x?x:fa[x]=fnd(fa[x]);}

struct node
{
	int x,y;
	node(){x=y=n;}
	inline void init(){x=y=n;}
	inline node operator+(node b)
	{
		node re;
		int res[4]={x,y,b.x,b.y};
		sort(res,res+4,[&](int x,int y){return a[x]>a[y];});
		re.x=res[0],re.y=n;
		for(int i=1;i<4;i++)
		{
			if(fnd(res[i])!=fnd(res[0]))
			{
				re.y=res[i];
				break;
			}
		}
		return re;
	}
}mx[LM];

int main()
{
	scanf("%d",&n);
	for(int i=1;i<=n;i++) scanf("%d",&a[i]);
	a[++n]=0;
	for(int i=1;i<=n;i++) fa[i]=i;
	long long ans=0;
	while(1)
	{
		for(int i=0;i<LM;i++) mx[i].init();
		for(int i=1;i<=n;i++) mx[a[i]].x=i;
		for(int i=0;i<LM;i++)
		{
			for(int j=0;(1<<j)<=i;j++)
			{
				if(i>>j&1) mx[i]=mx[i]+mx[i^(1<<j)];
			}
		}
		for(int i=1;i<=n;i++) mu[i]=mv[i]=-1;
		for(int u=1;u<=n;u++)
		{
			int rt=fnd(u);
			int v=mx[a[u]^(LM-1)].x;
			if(fnd(v)==rt) v=mx[a[u]^(LM-1)].y;
			if(fnd(v)==rt) continue;
			if(mu[rt]==-1||a[u]+a[v]>a[mu[rt]]+a[mv[rt]]) mu[rt]=u,mv[rt]=v;
		}
		bool f=true;
		for(int i=1;i<=n;i++)
		{
			if(fa[i]==i)
			{
				int u=mu[i],v=mv[i];
				if(u==-1) continue;
				if(fnd(u)==fnd(v)) continue;
				f=false;
				ans+=a[u]+a[v];
				fa[fnd(u)]=fnd(v);
			}
		}
		if(f) break;
	}
	for(int i=1;i<=n;i++) ans-=a[i];
	printf("%lld\n",ans);
	return 0;
}