树上最小 k 覆盖问题学习笔记

树上最小 k 覆盖问题是一种很典型的树上贪心问题,这里做一下小结。

树上最小 k 覆盖问题的形式一般是:

给定一棵树,边有边权,点 uu 有满足 du{0,1}d_u\in\{0,1\} 的点权 dud_u,称所有满足 du=1d_u=1uu 为“关键点”。

你需要选中一些点,令所有的“关键点”被这些点覆盖。一个“关键点”xx 被“覆盖”了当且仅当存在一个你选中的点 yy,满足 xxyy 之间的距离小于等于 kk,你需要最小化选中的点的数量。

这个问题有诸多变形,比较常见的是配合二分答案来考

我们可以使用树上贪心来解决这个问题。

首先设 fuf_u 表示 uu 子树内距离 uu 最远的未被覆盖的关键点距 uu 的距离,显然 fu=max{fx+1,xsonu}f_u=\max\{f_x+1,x\in son_u\}

考虑距离 uu 最远的未被覆盖的关键点 vvvv 显然有可能被 uu 子树里的点覆盖,也有可能被 uu 的父亲覆盖,或者被 uu 的兄弟子树里的点覆盖。那么我们先考虑被 uu 的子树里的点覆盖的情况,显然画出来时这样的:(黄色节点代表 vv,绿色节点代表一个已经被选中且覆盖黄色节点的点 ww

显然,vvww 不可能是 uu 的同一个儿子的子孙,因为那样的话 vv 早就已经被 ww 覆盖了。所以 vvww 之间的距离就是 vvuu 的距离加上 wwuu 的距离,也就是 fuf_u 加上 wwuu 的距离。那么我们就要wwuu 的距离尽可能短,不妨设它为 gug_u,那么显然 gu=min{gx+1,xsonu}g_u=\min\{g_x+1,x\in son_u\}

接下来分类讨论:

  • fu+gukf_u+g_u\le k,那么 uu 的子树能自给自足,这时 fu=inff_u=-\inf,表示 uu 的字树内没有关键点未被覆盖;

  • fu=kf_u=k,说明 vv 必须被 uu 覆盖了,这时需要选中 uu,同时整棵 uu 的子树都会被覆盖,则 fu=inff_u=-\infgu=0g_u=0ans++ans++

  • du=1d_u=1gu>kg_u>k,说明 uu 是关键点且不能被它子树内之前选中的点覆盖,那么果断交给自己的父亲和兄弟子树,fu=max(fu,0)f_u=\max(f_u,0)

这三种情况有可能同时成立多种,但并不互相干扰,所以需要写成三个 if 并行

最后注意f10f_1\ge0 说明还有关键点没被覆盖,等着 11 号节点的父亲和兄弟子树覆盖它们。但是 11 号节点没有父亲和兄弟,此时直接选中 11 号节点即可,ans++ans++

例题代码如下:

#include <iostream>
#include <cstdio>

using namespace std;

const long long S=1000005,MS=300005;
const int inf=1e8;

int n,m,d[MS];
int esum,to[S],nxt[S],h[MS];
int tot,f[MS],g[MS];

inline void add(int x,int y)
{
	to[++esum]=y;
	nxt[esum]=h[x];
	h[x]=esum;
}

void dfs(int u,int fa,int mid)
{
	f[u]=-inf;
	g[u]=inf;
	for(int i=h[u];i;i=nxt[i])
	{
		int v=to[i];
		if(v==fa)
		{
			continue;
		}
		dfs(v,u,mid);
		f[u]=max(f[u],f[v]+1);
		g[u]=min(g[u],g[v]+1);
	}
	if(f[u]+g[u]<=mid)
	{
		f[u]=-inf;
	}
	if(d[u]==1&&g[u]>mid)
	{
		f[u]=max(f[u],0);
	}
	if(f[u]==mid)
	{
		f[u]=-inf;
		g[u]=0;
		tot++;
	}
}

inline bool check(int mid)
{
	tot=0;
	dfs(1,0,mid);
	if(f[1]>=0)
	{
		tot++;
	}
	return tot<=m;
}

int main()
{
	scanf("%d%d",&n,&m);
	for(int i=1;i<=n;i++)
	{
		scanf("%d",&d[i]);
	}
	for(int i=1;i<=n-1;i++)
	{
		int x,y;
		scanf("%d%d",&x,&y);
		add(x,y);
		add(y,x);
	}
	int l=0,r=n,ans=-1;
	while(l<=r)
	{
		int mid=l+r>>1;
		if(check(mid))
		{
			ans=mid;
			r=mid-1;
		}
		else
		{
			l=mid+1;
		}
	}
	printf("%d\n",ans);
	return 0;
}

练习题目