QOJ5013 Astral Birth 做题记录

给定一个长 nn 的 01 序列 aa,对于每一个 1kn1\le k\le n,求出 aa 划分为 kk 个连续段重排后的最长不下降子序列的最大值。

1n3×1051\le n\le 3\times 10^5

下文称最长不下降子序列为 LNDS。

首先问题等价于划分成 k\le k 个连续段。

不难发现若 ai=ai+1a_i=a_{i+1} 则它们一定会被划分进同一个连续段,那么不妨将它们合并。

于是问题转化为了长度为 mm 的 01 相间序列 bib_i 且每个 ii 有大小 cic_i

发现由于 LNDS 中相邻两个 bib_i 相同的元素可以划分进同一段,所以直接做并不好做。那么考虑把不在 LNDS 中的元素删除,并把相邻两个相同元素合并。那么若最后剩下 ll 个元素,则:

  • LNDS 的长度为这 ll 个元素的 cic_i 之和;
  • l3l\ge 3 则需要划分 l1l-1 段因为 i,bi=0,bi+1=1\exist i,b_i=0,b_{i+1}=1

对于 l2l\le 2 的情况,不妨直接单独做。所以下面默认 l3l\ge 3

不难发现如下性质:

  • 若删除边界元素,元素个数会减少 11,否则会减少 22,因为两边的元素会合并;

  • 一定不会同时删除 bib_ibi+1b_{i+1}

那么先枚举 b1b_1bmb_m 有没有被删,则问题转化为要找一些两两不相邻的元素删去,使得删掉的元素的 cic_i 之和最小。

这是一个经典反悔贪心问题。

不难发现对于未被删除的元素中 cic_i 最小的元素 ii,若 ii 没被删掉则 i1i-1i+1i+1 就一定要都被删掉。

那么用链表+小根堆维护,设当前删掉的元素总和为 smsm,每次找到堆顶 pp

  • smsm 加上 cpc_p

  • 找到 pp 的前驱 ll 和后继 rr

    • llrr 均存在,则 pp 有可能不被删除,令 cp:=cl+crcpc_p:=c_l+c_r-c_p,在堆中加入 (cp,p)(c_p,p)

      此时原来的 pp 已被删除,现在的 ppllrr 合并成的新元素,所以在链表中删除 llrr

    • 否则 pp 一定会被删除,则在链表中删除 pp

      此时 llrr 一定不会被删除,那么在链表中删除 llrr(若存在)。

重复 kk 次该过程则 LNDS 中将会有 m(2k+t)m-(2k+t)tt11mm 中被删除的元素个数)个元素,需要划分为 m(2k+t)1m-(2k+t)-1 段,那么开个答案数组每次对 smsmmin\min 即可。

时间复杂度 O(nlogn)O(n\log n)

代码如下:

#include <iostream>
#include <cstdio>
#include <queue>

using namespace std;

const int S=300005;

int n;
char str[S];
int m,a[S],val[S];
int pre[S],nxt[S];
bool sta[S];
int ans[S];

inline void del(int x)
{
	pre[nxt[x]]=pre[x];
	nxt[pre[x]]=nxt[x];
	sta[x]=false;
}

inline void slove(int x,int y)
{
	for(int i=1;i<=m;i++) val[i]=a[i],pre[i]=i-1,nxt[i]=i+1,sta[i]=true;
	int c=0,sm=0;
	if(x) c++,sm+=val[1],del(nxt[1]);
	if(y) c++,sm+=val[m],del(pre[m]);
	del(1),del(m);
	priority_queue<pair<int,int>> q;
	for(int i=1;i<=m;i++) if(sta[i]) q.push(make_pair(-val[i],i));
	ans[m-c-1]=min(ans[m-c-1],sm);
	while(m-c-1>=1&&!q.empty())
	{
		auto u=q.top();
		q.pop();
		int p=u.second;
		if(!sta[p]) continue;
		c+=2;
		sm+=val[p];
		ans[m-c-1]=min(ans[m-c-1],sm);
		int l=pre[p],r=nxt[p];
		if(l!=0&&r!=m+1)
		{
			val[p]=val[l]+val[r]-val[p];
			q.push(make_pair(-val[p],p));
		}
		else del(p);
		if(l!=0) del(l);
		if(r!=m+1) del(r);
	}
}

int main()
{
	scanf("%d%s",&n,str+1);
	for(int i=1;i<=n;i++)
	{
		if(str[i]=='0')
		{
			if(m==0||a[m]<0) a[++m]=1;
			else a[m]++;
		}
		else
		{
			if(m==0||a[m]>0) a[++m]=-1;
			else a[m]--;
		}
	}
	for(int i=1;i<=m;i++) if(a[i]<0) a[i]=-a[i];
	for(int i=1;i<=n;i++) ans[i]=1e8;
	slove(0,0);
	slove(1,0);
	slove(0,1);
	slove(1,1);
	for(int i=1;i<=n;i++) ans[i]=n-ans[i];
	ans[1]=0;
	int c1=0;
	for(int i=1;i<=n;i++) c1+=str[i]=='1';
	for(int i=1,c0=0;i<=n;i++)
	{
		c0+=str[i]=='0';
		ans[1]=max(ans[1],c0+c1);
		c1-=str[i]=='1';
	}
	for(int i=2;i<=n;i++) ans[i]=max(ans[i],ans[i-1]);
	for(int i=1;i<=n;i++) printf("%d ",ans[i]);
	printf("\n");
	return 0;
}