点分治学习笔记

前置知识:树的重心

点分治是一种用来处理树上路径问题的思想,而非一种具体分治算法。

先看一到例题

给出一棵 nn 个点的无根树,求树上长度为 kk 的路径条数。

朴素的做法是枚举路径开头的点,跑 O(n)O(n) 树形 dp。但是这样的时间复杂度是 O(n2)O(n^2) 的,不够优秀。

考虑以点 uu 为根的子树的情况,显然所有路径分成两类:

  1. 经过 uu 的;
  2. 不经过 uu,在 uu 儿子的子树中的;

第二类路径可以递归到 uu 的儿子子树中处理,而第一类路径可以 O(n)O(n) 遍历整棵子树来处理。

若递归的时候直接让根为 uu 的某个儿子,显然来个链的情况就能把这个算法卡到 O(n2)O(n^2),所以每次选择的根很重要。

回忆一下重心:

定义:以重心为根,最大的子树大小是以所有节点为根的有根树中最小的。

性质:以重心为根,所有子树的大小都不超过整棵树的大小的一半

所以只要每次选择子树的重心为根就行了,因为重心的优美性质,所以这样递归的层数是 O(logn)O(\log n) 的,时间复杂度为 O(nlogn)O(n\log n)

回到例题,由于点分治的常数巨大,所以建议把所有询问离线下来做:

#include <iostream>
#include <cstdio>

using namespace std;

const int S=50005,BS=10000005;

int n,m,mxk,ks[S];
int esum,to[S],c[S],nxt[S],h[S];
int rot,siz[S],mx[S];
bool hasdis[BS],ans[S];
bool vis[S];

inline void add(int x,int y,int w)
{
	to[++esum]=y;
	c[esum]=w;
	nxt[esum]=h[x];
	h[x]=esum;
}

void gethev(int u,int fa,int sizz)
{
	siz[u]=1;
	mx[u]=0;
	for(int i=h[u];i;i=nxt[i])
	{
		int v=to[i];
		if(v==fa||vis[v]) continue;
		gethev(v,u,sizz);
		siz[u]+=siz[v];
		mx[u]=max(mx[u],siz[v]);
	}
	mx[u]=max(mx[u],sizz-siz[u]);
	if(rot==-1||mx[u]<mx[rot]) rot=u;
}

void getdis(int u,int fa,int dis)
{
	for(int i=1;i<=m;i++) if(ks[i]>=dis) ans[i]|=hasdis[ks[i]-dis];
	for(int i=h[u];i;i=nxt[i])
	{
		int v=to[i],w=c[i];
		if(v==fa||vis[v]) continue;
		getdis(v,u,dis+w);
	}
}

void adddis(int u,int fa,int dis,bool f)
{
	if(dis<=mxk) hasdis[dis]=f;
	for(int i=h[u];i;i=nxt[i])
	{
		int v=to[i],w=c[i];
		if(v==fa||vis[v]) continue;
		adddis(v,u,dis+w,f);
	}
}

void slove(int u)
{
	hasdis[0]=true;
	for(int i=h[u];i;i=nxt[i])
	{
		int v=to[i],w=c[i];
		if(vis[v]) continue;
		getdis(v,u,w);
		adddis(v,u,w,true);
	}
	for(int i=h[u];i;i=nxt[i])
	{
		int v=to[i],w=c[i];
		if(vis[v]) continue;
		adddis(v,u,w,false);
	}
	vis[u]=true;
	for(int i=h[u];i;i=nxt[i])
	{
		int v=to[i];
		if(vis[v]) continue;
		rot=-1;
		gethev(v,u,siz[v]);
		slove(rot);
	}
}

int main()
{
	scanf("%d%d",&n,&m);
	for(int i=1;i<=n-1;i++)
	{
		int x,y,w;
		scanf("%d%d%d",&x,&y,&w);
		add(x,y,w);
		add(y,x,w);
	}
	for(int i=1;i<=m;i++) scanf("%d",&ks[i]),mxk=max(mxk,ks[i]);
	rot=-1;
	gethev(1,0,n);
	slove(rot);
	for(int i=1;i<=m;i++) puts(ans[i]?"AYE":"NAY");
	return 0;
}