给定一棵 个节点的无标号基环树和颜色数 ,给每个节点染上 种颜色中的一种,求有多少种本质不同的染色方案。
,。
首先有根树的方案数是好算的,把儿子按同构情况分类(判断同构使用树哈希),同一类用插板法处理即可。因为:
且组合数的 ,所以有根树的方案数可以在 的时间复杂度内算出。
对于基环树,找出环后求出环上每一个点的子树的答案和哈希值,然后套 Burnside 引理即可。
具体考虑枚举循环移位的次数,用序列哈希判断循环移位后是否同构。设移位了 次,则会形成 个置换环( 为环上节点按顺序形成的序列)。设置换环集合为 ,则由于同一个置换环中的所有元素必定同构,所以只需要算 即可,其中 表示 子树的答案。
而注意到 一定在不同的置换环中,所以 Burnside 引理中的 即为 。
时间复杂度 ,代码如下:
#include <iostream>
#include <algorithm>
#include <cstdio>
#include <vector>
using namespace std;
typedef unsigned long long ull;
const int S=300005,p=1000000007;
const ull bse=114513,mask=1145143191981ull;
int inv[S];
int n,m;
vector<int> g[S],son[S];
bool vis[S],ins[S];
int top,sta[S];
vector<int> cir;
ull h[S];
int res[S];
ull mul[S],sum[S];
inline int qpow(int x,int y)
{
int res=1;
for(;y>0;y>>=1,x=1ll*x*x%p) res=y&1?1ll*res*x%p:res;
return res;
}
inline int C(int n,int m)
{
int res=1;
for(int i=n;i>=n-m+1;i--) res=1ll*res*((i%p+p)%p)%p;
for(int i=1;i<=m;i++) res=1ll*res*inv[i]%p;
return res;
}
void dfs(int u)
{
vis[u]=true;
sta[++top]=u;
ins[u]=true;
for(int v:g[u])
{
if(vis[v])
{
if(ins[v])
{
while(1)
{
cir.push_back(sta[top--]);
if(sta[top+1]==v) break;
}
}
}
else dfs(v);
if(!cir.empty()) return;
}
top--;
ins[u]=false;
}
inline ull shift(ull x)
{
x^=mask;
x^=x<<13;
x^=x>>7;
x^=x<<17;
x^=mask;
return x;
}
void calc(int u)
{
for(int v:son[u]) calc(v);
sort(son[u].begin(),son[u].end(),[&](int x,int y){return h[x]<h[y];});
h[u]=1145141;
for(int v:son[u]) h[u]+=shift(h[v]);
res[u]=m;
for(int i=0,cnt=0;i<son[u].size();i++)
{
cnt++;
if(i==son[u].size()-1||h[son[u][i]]!=h[son[u][i+1]])
{
res[u]=1ll*res[u]*C(res[son[u][i]]+cnt-1,cnt)%p;
cnt=0;
}
}
}
inline int gcd(int x,int y)
{
if(x==0||y==0) return x+y;
int t=x%y;
while(t!=0) x=y,y=t,t=x%y;
return y;
}
int main()
{
freopen("color.in","r",stdin);
freopen("color.out","w",stdout);
mul[0]=1;
for(int i=1;i<=S-3;i++) inv[i]=qpow(i,p-2),mul[i]=mul[i-1]*bse;
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++)
{
int x;
scanf("%d",&x);
g[x].push_back(i);
}
for(int i=1;i<=n;i++)
{
if(!vis[i])
{
while(top>0) ins[sta[top--]]=false;
dfs(i);
if(cir.size()>0) break;
}
}
int len=cir.size();
for(int i=0;i<len/2;i++) swap(cir[i],cir[len-i-1]);
for(int i=1;i<=n;i++) vis[i]=false;
for(int i=0;i<len;i++) vis[cir[i]]=true;
for(int i=1;i<=n;i++)
{
if(vis[i]) continue;
for(int v:g[i]) son[i].push_back(v);
}
for(int i=0;i<len;i++)
{
int u=cir[i],r=cir[(i+1)%len];
for(int v:g[u]) if(v!=r) son[u].push_back(v);
calc(u);
}
sum[0]=h[cir[0]];
for(int i=1;i<len;i++) sum[i]=sum[i-1]*bse+h[cir[i]];
int ans=0,cnt=0;
for(int i=0;i<len;i++)
{
ull pre=sum[len-1];
if(i>0) pre=(sum[len-1]-sum[i-1]*mul[len-i])*mul[i]+sum[i-1];
if(pre==sum[len-1])
{
cnt++;
int now=1,g=gcd(len,i);
for(int j=0;j<g;j++) now=1ll*now*res[cir[j]]%p;
ans=(ans+now)%p;
}
}
ans=1ll*ans*qpow(cnt,p-2)%p;
printf("%d\n",ans);
return 0;
}