【2023NOI模拟赛09】点分治 做题记录

小L初学的时候学到了如下的点分治算法:

  • 初始时当前连通块是整棵树。
  • 首先,在当前连通块中找到任意一个点 uu 作为该次的分治中心(不必是重心)。
  • 其次,把点 uu 在当前连通块中删去,可以得到若干个连通块。对于每个连通块再递归进行这样的操作。

不难发现,这个点分治在最坏情况下递归层数可以达到 O(n)O(n) 层。现在,好奇的小 L 想要知道,对于一棵给定的包含 nn 个节点的树,他有多少种不同的点分治方案呢?因为答案可能很大,你只需要输出它对 109+710^9+7 取模的值即可。

两种点分治方案不同当且仅当某一个连通块所选的点分治中心不同。

为了避免因为大家所学算法的具体细节不同出现歧义,我们还提供了一份暴力代码来具体描述这个算法。

const int mod=1e9+7;
const int maxn=5005;
bool vis[maxn];
vector<int> e[maxn];
int n;
inline void view_all(vector<int> &cur,int x,int fa)
{
	cur.push_back(x);
	for(int p: e[x])
	{
		if (vis[p]) continue;
		if (p == fa) continue;
		view_all(cur, p, x);
	}
}
inline int calc(int x)
{
	vector<int> cur;
	int ans = 0;
	view_all(cur, x, -1);
	for(auto = w : cur)
	{
		int res = 1;
		vis[w] = 1;
		for(auto p: e[w])
		{
			if (vis[p]) continue;
			res = 1ll * res * calc(p) % mod;
		}
		vis[w] = 0;
		ans = (ans + res) % mod;
	}
	return ans;
}
inline int get_ans()
{
	return calc(1);
}

1n50001\le n\le 5000

考虑原树上的一条边 (u,v)(u,v)

考虑断掉 (u,v)(u,v) 后对点分树的影响,设断掉 (u,v)(u,v)uu 子树中的点为绿点,vv 子树中的点为红点,那么由于所有红点和所有绿点只有一条边相连,所以断边前的点分树一定只有一条极长链是既包含绿点又包含红点的:

那么假设这条链是这样的:

断边之后为了保证祖先-后代关系,就一定要这样连:

即每个点的新父亲都是它所有祖先中最深的和它颜色一样的点。

注意到由于保证了祖先-后代关系,所以同一个点分树断不同的边得到的结果是互不相同的,同一个点分树断相同的边得到的结果是相同的。

考虑把这个操作倒过来,合并两棵子树。观察到这个操作对 uuvv 这两个点在对应点分树中的子孙都没有影响,所以把 uuvv 到它的子树的点分树的根的链画出来:

那么为了保证不同的两棵小点分树合并出来的点分树不同,相同两棵小点分树合并出来的点分树相同,同样需要保证合并后的祖先-后代关系一致,那么类似归并排序,只要保证这两条链的相对顺序就行了。设绿色链的长度为 xx,红色链的长度为 yy,那么合并的方案数即为 (x+yx)\binom{x+y}{x}

考虑用树上背包维护合并的过程,设 dpu,xdp_{u,x} 表示 uu 的子树合并完毕,uu 在点分树中到根的链长度为 xx 的点分树个数,那么对于每个 uu 的儿子 vv,有转移:

i=1xj=xisizvdpv,j(x1i1)dpu,x\sum\limits_{i=1}^{x}\sum\limits_{j=x-i}^{siz_v}dp_{v,j}\binom{x-1}{i-1}\to dp_{u,x}

即枚举 uu 的祖先中有多少个属于 vv 的点分树的点。注意到这样时间复杂度是 O(n3)O(n^3) 的,注意到组合数与 jj 无关,所以可以做后缀和,这样就可以优化到 O(n2)O(n^2),足以通过本题。

代码如下:

#include <iostream>
#include <cstdio>

using namespace std;

const int S=5005,p=1000000007;

int C[S][S];
int n;
int esum,to[S*2],nxt[S*2],h[S];
int siz[S],dp[S][S],pd[S];

inline void add(int x,int y)
{
	to[++esum]=y;
	nxt[esum]=h[x];
	h[x]=esum;
}

inline void addd(int &x,int y)
{
	x+=y;
	if(x>=p) x-=p;
}

void dfs(int u,int fa)
{
	siz[u]=1,dp[u][1]=1;
	for(int i=h[u];i;i=nxt[i])
	{
		int v=to[i];
		if(v==fa) continue;
		dfs(v,u);
		for(int j=1;j<=siz[u];j++)
		{
			for(int k=0;k<=siz[v];k++)
			{
				addd(pd[j+k],1ll*dp[u][j]*dp[v][k]%p*C[j+k-1][j-1]%p);
			}
		}
		siz[u]+=siz[v];
		for(int j=1;j<=siz[u];j++) dp[u][j]=pd[j],pd[j]=0;
	}
	for(int i=siz[u]-1;i>=0;i--) addd(dp[u][i],dp[u][i+1]);
}

int main()
{
	freopen("dianfen.in","r",stdin);
	freopen("dianfen.out","w",stdout);
	for(int i=0;i<=S-3;i++)
	{
		C[i][0]=1;
		for(int j=1;j<=i;j++) C[i][j]=(C[i-1][j]+C[i-1][j-1])%p;
	}
	scanf("%d",&n);
	for(int i=1;i<=n-1;i++)
	{
		int x,y;
		scanf("%d%d",&x,&y);
		add(x,y),add(y,x);
	}
	dfs(1,0);
	printf("%d\n",dp[1][1]);
	return 0;
}