【2023NOI模拟赛17】逆转函数 做题记录

对于一个长度为 kk,值域 [1,m][1,m] 的正整数序列 {a1,a2,a3,,ak}\{a_1,a_2,a_3,\dots,a_k\},定义它的逆转函数 ff 为满足 {f(a1),f(a2),f(a3),,f(ak)}={ak,ak1,ak2,,a1}\{f(a_1),f(a_2),f(a_3),\dots,f(a_k)\}=\{a_k,a_{k-1},a_{k-2},\dots,a_1\} 的定义域和值域均为 [1,m][1,m] 的函数。

给定长度为 nn,值域 [1,m][1,m] 的正整数序列 aa,定义 g(l,r)g(l,r){al,al+1,al+2,,ar}\{a_l,a_{l+1},a_{l+2},\dots,a_r\} 的逆转函数个数,你需要求出 l=1nr=lng(l,r)\sum\limits_{l=1}^{n}\sum\limits_{r=l}^n g(l,r)998244353998244353 取模的结果。

1n,m3×1051\le n,m\le 3\times 10^5

不难发现,若一个序列存在逆转函数,那么它的逆转函数的个数就是 mmcntm^{m-cnt},其中 cntcnt 是这个序列中不同整数的个数。

考虑 O(nm)O(nm) 暴力,由于翻转的性质,枚举中心点显然是更方便的,那么可以暴力从中心点往两边拓展,开桶维护答案。

考虑优化,先考虑长度为奇数的区间,设 a[ileni,i+leni]a_{[i-len_i,i+len_i]} 是以 ii 为中心的最长的合法子区间。不难发现由于逆转函数的性质,对于每个 ii,若存在 p<ip<iip+lenpi\le p+len_p,那么 lenilen_i 至少是 min(len2pi,p+lenpi)\min(len_{2p-i},p+len_p-i),因为这两个区间构成双射:

那么根据这个性质,我们可以直接套 manacher,这样就可以求出每个位置的 lenilen_i

考虑方案数怎么求,在继承区间的时候,若已经知道当前区间内有多少个不同的数 smcntsmcnt,且知道所有 [ij,i+j][i-j,i+j]jmin(len2pi,p+lenpi)j\le \min(len_{2p-i},p+len_p-i))的区间的答案的和 smanssmans,那么在暴力拓展的时候就可以直接利用前驱后继来快速维护这两个东西。

考虑在继承的时候如何快速获取这两个东西,显然继承的区间一定也是由某个区间继承过来的,并且一个区间只会从一个区间继承过来,那么设 ii 是从 faifa_i 继承过来的,继承过来的区间为 [ilen1i,i+len1i][i-len1_i,i+len1_i],暴力拓展后的区间为 [ilen2i,i+len2i][i-len2_i,i+len2_i]

那么在暴力拓展的时候可以开个 vector 维护拓展到每个位置的状态 (smcnt,smans)(smcnt,smans),继承的时候倍增找到第一个 len1j<min(len2pi,p+lenpi)len1_j<\min(len_{2p-i},p+len_p-i)jj 从它的 vector 里继承过来即可。

需要注意的是 len1len1 没有单调性,所以要记一下路径上的 min\min

长度为偶数的情况也是好做的,再跑一次即可。

代码如下:

// Problem: #196. 逆转函数
// Contest: Hydro
// URL: http://oiclass.com/d/AKNOI/p/196
// Memory Limit: 128 MB
// Time Limit: 1000 ms
// 
// Powered by CP Editor (https://cpeditor.org)

#include <iostream>
#include <cstdio>
#include <vector>
#include <cstring>

using namespace std;

const int S=300005,BS=25;
const int p=998244353;

int n,m;
int a[S],pos[S];
int lft[S],rig[S];
int len1[S],len2[S];
int fa[S][BS],mn[S][BS];
vector<pair<int,int>> sum[S];
int ans;

inline int qpow(int x,int y)
{
	int res=1;
	for(;y>0;y>>=1,x=1ll*x*x%p) res=y&1?1LL*res*x%p:res;
	return res;
}

inline void add(int &x,int y)
{
	x+=y;
	if(x>=p) x-=p;
}

int main()
{
	freopen("invfunc.in","r",stdin);
	freopen("invfunc.out","w",stdout);
	scanf("%d%d",&n,&m);
	for(int i=1;i<=n;i++) scanf("%d",&a[i]);
	for(int i=1;i<=m;i++) pos[i]=0;
	for(int i=1;i<=n;i++)
	{
		lft[i]=pos[a[i]];
		pos[a[i]]=i;
	}
	for(int i=1;i<=m;i++) pos[i]=n+1;
	for(int i=n;i>=1;i--)
	{
		rig[i]=pos[a[i]];
		pos[a[i]]=i;
	}
	for(int i=0;i<=n;i++)
	{
		len1[i]=len2[i]=0;
		memset(fa[i],0,sizeof(fa[i]));
		memset(mn[i],127,sizeof(mn[i]));
		sum[i].clear();
	}
	memset(mn[0],0,sizeof(mn[0]));
	sum[0].push_back(make_pair(1,qpow(m,m-1)));
	for(int i=1,j=0;i<=n;i++)
	{
		int pp=0;
		if(i<=j+len2[j])
		{
			pp=2*j-i;
			len1[i]=min(len2[pp],j+len2[j]-i);
		}
		fa[i][0]=pp;
		mn[i][0]=len1[i];
		for(int k=1;k<=BS-3;k++)
		{
			fa[i][k]=fa[fa[i][k-1]][k-1];
			mn[i][k]=min(mn[i][k-1],mn[fa[i][k-1]][k-1]);
		}
		int pp2=pp;
		for(int k=BS-3;k>=0;k--) if(mn[pp2][k]>len1[i]) pp2=fa[pp2][k];
		sum[i].push_back(sum[pp2][len1[i]-len1[pp2]]);
		len2[i]=len1[i];
		int smcnt=sum[i][0].first,smans=sum[i][0].second;
		len2[i]++;
		while(1)
		{
			if(i-len2[i]<1||i+len2[i]>n)
			{
				len2[i]--;
				break;
			}
			if(rig[i-len2[i]]<i+len2[i])
			{
				int pp=rig[i-len2[i]];
				if(a[i+len2[i]]!=a[2*i-pp])
				{
					len2[i]--;
					break;
				}
			}
			else smcnt++;
			if(lft[i+len2[i]]>=i-len2[i])
			{
				int pp=lft[i+len2[i]];
				if(a[i-len2[i]]!=a[2*i-pp])
				{
					len2[i]--;
					break;
				}
			}
			else smcnt++;
			add(smans,qpow(m,m-smcnt));
			sum[i].push_back(make_pair(smcnt,smans));
			len2[i]++;
		}
		if(i+len2[i]>j+len2[j]) j=i;
	}
	for(int i=1;i<=n;i++) add(ans,sum[i][sum[i].size()-1].second);
	for(int i=0;i<=n;i++)
	{
		len1[i]=len2[i]=-1;
		memset(fa[i],0,sizeof(fa[i]));
		memset(mn[i],127,sizeof(mn[i]));
		sum[i].clear();
	}
	memset(mn[0],-1,sizeof(mn[0]));
	sum[0].push_back(make_pair(0,0));
	for(int i=1,j=0;i<=n-1;i++)
	{
		int pp=0;
		if(i<=j+1+len2[j])
		{
			pp=2*j-i;
			len1[i]=min(len2[pp],j+len2[j]-i);
		}
		fa[i][0]=pp;
		mn[i][0]=len1[i];
		for(int k=1;k<=BS-3;k++)
		{
			fa[i][k]=fa[fa[i][k-1]][k-1];
			mn[i][k]=min(mn[i][k-1],mn[fa[i][k-1]][k-1]);
		}
		int pp2=pp;
		for(int k=BS-3;k>=0;k--) if(mn[pp2][k]>len1[i]) pp2=fa[pp2][k];
		sum[i].push_back(sum[pp2][len1[i]-len1[pp2]]);
		len2[i]=len1[i];
		int smcnt=sum[i][0].first,smans=sum[i][0].second;
		len2[i]++;
		while(1)
		{
			if(i-len2[i]<1||i+1+len2[i]>n)
			{
				len2[i]--;
				break;
			}
			if(rig[i-len2[i]]<i+1+len2[i])
			{
				int pp=rig[i-len2[i]];
				if(a[i+1+len2[i]]!=a[2*i+1-pp])
				{
					len2[i]--;
					break;
				}
			}
			else smcnt++;
			if(lft[i+1+len2[i]]>=i-len2[i])
			{
				int pp=lft[i+1+len2[i]];
				if(a[i-len2[i]]!=a[2*i+1-pp])
				{
					len2[i]--;
					break;
				}
			}
			else smcnt++;
			add(smans,qpow(m,m-smcnt));
			sum[i].push_back(make_pair(smcnt,smans));
			len2[i]++;
		}
		if(i+len2[i]>j+len2[j]) j=i;
	}
	for(int i=1;i<=n-1;i++) add(ans,sum[i][sum[i].size()-1].second);
	printf("%d\n",ans);
	return 0;
}