AGC058F Authentic Tree DP 做题记录

对于 nn 个点的一棵无根树 TT,定义 f(T)f(T)

  • n=1n=1,则 f(T)=1f(T)=1
  • 否则:
    • 对于一条树边 ee,定义 Tx(e)T_x(e)Ty(e)T_y(e)TT 删去 ee 后分裂出的两棵树(不管顺序);
    • f(T)=(eedge(T)f(Tx(e))×f(Ty(e)))×1nf(T)=\left(\sum\limits_{e\in \text{edge}(T)} f(T_x(e))\times f(T_y(e))\right)\times \frac{1}{n}

给定一棵 nn 个点的无根树 AA,求 f(A) mod 998244353f(A)\text{ mod }998244353 的值。

2n50002\le n\le 5000

发现有个 1n\frac{1}{n} 不好处理,考虑组合意义。

若删边改成删点,这个 1n\frac{1}{n} 就相当于每次等概率随机选择一个点删掉,所以 f(T)=1f(T)=1。同理,若乘的是 1n1\frac{1}{n-1},则 f(T)f(T) 也会恒为 11

接下来开始人类智慧:

  • 为每条边建一个“边点”,再给“边点”连 1-1 个点(可以认为是 9982443531998244353-1 个),那么现在点数变成了 nn,答案即为等概率随机一个 TT 的点排列,满足每个”边点“都在其相邻点之前的概率;

这个结论的证明挺显然的,因为若先删掉”边点“相邻的点则树的形态不合法。

接下来考虑把每条边按照点的先后关系定向,则相当于求满足定向关系的概率。发现有的父亲比儿子大,有的父亲比儿子小,不好计算。那么不妨容斥,对于每条向上的边,其可以删掉或者改成向下的边,每个由向上改为向下的边都会贡献 1-1 的容斥系数。

那么就可以做了,记录一下从每个点开始的外向树大小,背包即可。

时间复杂度 O(n2)O(n^2)

#include <iostream>
#include <cstdio>
#include <vector>

using namespace std;

const int S=5005,p=998244353;

int n;
int fra[S],inv[S];
vector<int> g[S];
int siz[S],tf[S],f[S][S];

inline int qpow(int x,int y)
{
	int res=1;
	for(;y>0;y>>=1,x=1ll*x*x%p) res=y&1?1ll*res*x%p:res;
	return res; 
}

inline void add(int &x,int y)
{
	x+=y;
	if(x>=p) x-=p;
}

void dfs(int u,int fa)
{
	siz[u]=1;
	f[u][1]=1;
	for(int v:g[u])
	{
		if(v==fa) continue;
		dfs(v,u);
		for(int i=1;i<=siz[u]+siz[v];i++) tf[i]=0;
		for(int i=1;i<=siz[u];i++)
		{
			for(int j=1;j<=siz[v];j++) 
			{
				int pre=1ll*f[v][j]*inv[j]%p;
				add(tf[i+j],p-1ll*f[u][i]*pre%p);
				add(tf[i],1ll*f[u][i]*pre%p);
			}
		}
		siz[u]+=siz[v];
		for(int i=1;i<=siz[u];i++) f[u][i]=tf[i];
	}
	for(int i=1;i<=siz[u];i++) f[u][i]=1ll*f[u][i]*inv[i]%p;
}

int main()
{
	fra[0]=1;
	for(int i=1;i<=S-3;i++) fra[i]=1ll*fra[i-1]*i%p;
	inv[S-3]=qpow(fra[S-3],p-2);
	for(int i=S-3;i>=1;i--) inv[i-1]=1ll*inv[i]*i%p;
	for(int i=1;i<=S-3;i++) inv[i]=1ll*inv[i]*fra[i-1]%p;
	scanf("%d",&n);
	for(int i=1;i<=n-1;i++)
	{
		int x,y;
		scanf("%d%d",&x,&y);
		g[x].push_back(y);
		g[y].push_back(x);
	}
	dfs(1,0);
	int ans=0;
	for(int i=1;i<=n;i++) add(ans,f[1][i]);
	printf("%d\n",ans);
	return 0;
}