QOJ2211 IOI Problem Revised 做题记录

nn 个人在一个长 LL 的环上,第 ii 个人位于坐标 aia_i 处。你要选择 kk 个关键点,最小化每个人到最近的关键点的距离的和。输出方案。

1kn2×1051\le k\le n\le 2\times 10^51L10121\le L\le 10^{12}

首先注意到最近关键点相同的人一定位于区间内,故问题可以转化为将人划分为若干个区间,区间 [l,r][l,r] 代价为 liraial+r2\sum\limits_{l\le i\le r} \left|a_i-a_{\lfloor\frac{l+r}{2}\rfloor}\right|,最小化代价和。

考虑链的情况怎么做,由于 [l,r][l,r][l+1,r1][l+1,r-1] 的中位数相同,故区间代价 wl,rw_{l,r} 满足四边形不等式。

那么这是蒙日矩阵最短路问题,直接 wqs 二分 + 二分队列优化转移即可,复杂度 O(nlogVlogn)O(n\log V\log n)

接下来发扬人类智慧,先随便找个位置(假定为 11 前)将环断开,求出此时的最优解的分界点 aa。那么对于环上的最优解 bb,一定有 bib_i 在区间 [ai1,ai][a_{i-1},a_i] 中:

证明考虑若不是这种情况则一定 有某段红色区间包含两个蓝色端点 且 有某段蓝色区间包含两个红色端点,那么根据相交优于包含,我们可以调整这两处(绿色线),使得这两个方案的代价之和减小(变为紫色和棕色线):

由于此时某条线仍然是从 11 处断开的方案,故要么 aa 不是从 11 处断开的最优方案,要么 bb 不是环上的最优方案,矛盾。

根据这个性质,对于某个 x[1,a1]x\in [1,a_1],求出从 x+1x+1 处断开的最优方案(起始点在 xx 的最优方案)cxc_x,则一定有 c[1,x1],i[ai1,cx,i]c_{[1,x-1],i}\in [a_{i-1},c_{x,i}]c[x+1,a1],i[cx,i,ai]c_{[x+1,a_1],i}\in[c_{x,i},a_{i}]

那么考虑分治,每次求出 midmid 为起始点的最小方案,往两边递归。求最小方案使用分治优化转移,这样每一层都会遍历整个环,复杂度 O(nlog2n)O(n\log^2 n)

但是还有一个问题,每一层可能会多一些点。具体的,注意到每次求一个起始点的最小方案至少是 O(k)O(k) 的,而若选择的区间长度为 ll 则分治底层要跑 ll 次,复杂度至少是 O(lk)O(lk)。所以还需要找到最短的区间进行分治,这样底层复杂度就是 O(nk×k)=O(n)O(\frac{n}{k}\times k)=O(n) 的,正确。

总时间复杂度 O(nlogVlogn+nlog2n)O(n\log V\log n+n\log ^2n),代码如下:

#include <iostream>
#include <algorithm>
#include <cstdio>
#include <vector>

using namespace std;

typedef long long ll;

const int S=200005;
const ll inf=1e18;

int n,k;
ll L,a[S*2],sm[S*2];
ll tmp[S];
ll ans;
vector<int> res;

inline ll getsm(int l,int r){return sm[r]-sm[l-1];}

inline ll calc(int l,int r)
{
	if(l>r) r+=n;
	int m=l+r>>1;
	// printf("[%d %d %d] %lld %lld %lld\n",l,m,r,a[m],getsm(l,m),getsm(m+1,r));
	return a[m]*(m-l+1)-getsm(l,m)+
		   getsm(m+1,r)-a[m]*(r-m);
}

namespace wqs
{
	ll f[S];
	int ctl[S],ctr[S];
	int hed,til,q[S],ql[S],qr[S];
	inline bool cmp(int i1,int i2,int j,bool sml) // i1<i2
	{
		ll x1=f[i1]+calc(i1+1,j);
		ll x2=f[i2]+calc(i2+1,j);
		if(x1!=x2) return x1<x2;
		if(sml) return ctl[i1]<ctl[i2];
		else return ctr[i1]>ctr[i2];
	}
	void get(ll k,bool sml)
	{
		for(int i=1;i<=n;i++) f[i]=inf;
		if(sml) for(int i=1;i<=n;i++) ctl[i]=1e8;
		else for(int i=1;i<=n;i++) ctr[i]=-1e8;
		f[0]=0;
		(sml?ctl:ctr)[0]=0;
		hed=1,til=0;
		q[++til]=0,ql[til]=1,qr[til]=n;
		for(int i=1;i<=n;i++)
		{
			while(qr[hed]<i) hed++;
			int j=q[hed];
			f[i]=f[j]+calc(j+1,i)-k;
			// printf("%d: %d\n",i,j);
			(sml?ctl:ctr)[i]=(sml?ctl:ctr)[j]+1;
			if(qr[hed]==i) hed++;
			else ql[hed]=i+1;
			while(hed<=til&&cmp(i,q[til],ql[til],sml)) til--;
			if(hed>til) q[++til]=i,ql[til]=i+1,qr[til]=n;
			else
			{
				int lb=ql[til],rb=qr[til],res=rb;
				while(lb<=rb)
				{
					int mid=lb+rb>>1;
					if(!cmp(i,q[til],mid,sml)) res=mid,lb=mid+1;
					else rb=mid-1;
				}
				int rr=qr[til];
				qr[til]=res;
				if(res<n) q[++til]=i,ql[til]=res+1,qr[til]=n;
			}
		}
	}
	inline vector<int> slove()
	{
		ll lb=-inf,rb=0,res=0;
		while(lb<=rb)
		{
			ll mid=lb+rb>>1;
			get(mid,false);
			if(ctr[n]>=k) res=mid,rb=mid-1;
			else lb=mid+1;
		}
		get(res,true);
		get(res,false);
		vector<int> vec;
		vec.push_back(n);
		for(int i=n-1,lst=n,tk=k-1;i>=1;i--)
		{
			// printf("%lld %lld %d [%d %d]\n",
				// f[i]+calc(i+1,lst)-res,f[lst],tk,ctl[i],ctr[i]);
			if(f[i]+calc(i+1,lst)-res==f[lst]&&ctl[i]<=tk&&tk<=ctr[i])
			{
				// printf(">> %d\n",i);
				vec.push_back(i);
				lst=i;
				tk--;
			}
		}
		reverse(vec.begin(),vec.end());
		return vec;
	}
}

namespace slove
{
	ll f[S];
	int lst[S*2],idx[S];
	void get(int l,int r,int kl,int kr)
	{
		if(l>r) return;
		int mid=l+r>>1,p=kl;
		f[mid]=inf;
		for(int i=kl;i<=kr;i++)
		{
			ll pre=f[i]+calc(i+1,mid);
			if(pre<f[mid]) f[mid]=pre,p=i;
		}
		lst[mid]=p;
		if(l==r) return;
		get(l,mid-1,kl,p),get(mid+1,r,p,kr);
	}
	void slove(vector<pair<int,int> > &seq)
	{
		int l0=seq[0].first,r0=seq[0].second;
		if(l0>r0) return;
		// for(auto x:seq) printf("[%d %d] ",x.first,x.second);
		// printf("\n");
		int mid=l0+r0>>1;
		f[mid]=0;
		for(int i=1;i<k;i++)
		{
			int l,r;
			if(i==1) l=r=mid;
			else l=seq[i-1].first,r=seq[i-1].second;
			int pl=seq[i].first,pr=seq[i].second;
			if(r==pl)
			{
				get(pl+1,pr,l,r);
				f[pl]=inf;
				for(int j=l;j<=r-1;j++)
				{
					ll pre=f[j]+calc(j+1,pl);
					if(pre<f[pl]) f[pl]=pre,lst[n+pl]=j;
				}
			}
			else get(pl,pr,l,r);
		}
		for(int i=k-1;i>=1;i--)
		{
			int pl=seq[i].first,pr=seq[i].second;
			for(int j=pl;j<=pr;j++) idx[j]=i+1;
		}
		idx[mid]=1;
		ll premn=inf;
		int p=0;
		for(int i=seq[k-1].first;i<=seq[k-1].second;i++)
		{
			ll pre=f[i]+calc(i+1,mid);
			if(pre<premn) premn=pre,p=i;
		}
		// printf("%lld\n",premn);
		vector<int> vec;
		int tk=k;
		vec.push_back(p);
		while(p!=mid)
		{
			if(idx[p]!=tk) p=lst[n+p];
			else p=lst[p];
			tk--;
			vec.push_back(p);
			// printf(">> %d\n",p);
		}
		reverse(vec.begin(),vec.end());
		// printf("]]]] ");
		// for(int x:vec) printf("%d ",x);
		// printf("\n");
		if(premn<ans)
		{
			ans=premn;
			res=vec;
		}
		if(l0==r0) return;
		vector<pair<int,int> > s1,s2;
		s1.emplace_back(l0,mid-1);
		s2.emplace_back(mid+1,r0);
		for(int i=1;i<k;i++)
		{
			int l=seq[i].first,r=seq[i].second,m=vec[i];
			s1.emplace_back(l,m);
			s2.emplace_back(m,r);
		}
		slove(s1),slove(s2);
	}
}

int main()
{
	scanf("%d%d%lld",&n,&k,&L);
	for(int i=1;i<=n;i++) scanf("%lld",&a[i]),a[i+n]=L+a[i];
	for(int i=1;i<=n*2;i++) sm[i]=sm[i-1]+a[i];
	if(k==1)
	{
		ll ans=inf,res=0;
		for(int i=1;i<=n;i++)
		{
			ll pre=calc(i,i+n-1);
			if(pre<ans) ans=pre,res=a[i+i+n-1>>1];
		}
		if(res>=L) res-=L;
		printf("%lld\n%lld\n",ans,res);
		return 0;
	}
	vector<int> pot=wqs::slove();
	// rotate
	int mnd=n,mnp=0;
	for(int i=0;i<k;i++)
	{
		int pre=(pot[(i+1)%k]-pot[i]+n)%n;
		if(pre<mnd) mnd=pre,mnp=i;
	}
	// for(int x:pot) printf("%d ",x);
	// printf("\n");
	// printf(">> %d\n",mnd);
	int beg=pot[mnp]+1;
	for(int i=beg;i<=n;i++) tmp[i-beg+1]=a[i];
	for(int i=1;i<beg;i++) tmp[n-beg+1+i]=L+a[i];
	for(int i=1;i<=n;i++) a[i]=tmp[i],a[i+n]=L+a[i];
	for(int i=1;i<=n*2;i++) sm[i]=sm[i-1]+a[i];
	for(int i=0;i<k;i++) pot[i]=(pot[i]-beg+1-1+n)%n+1;
	sort(pot.begin(),pot.end());
	// for(int i=1;i<=n;i++) printf("%lld ",a[i]);
	// printf("\n");
	// for(int x:pot) printf("%d ",x);
	// printf("\n");
	// printf("%lld\n",calc(1,3)+calc(4,5));
	vector<pair<int,int> > seq;
	for(int i=0,lst=1;i<k;i++)
	{
		seq.emplace_back(lst,pot[i]);
		lst=pot[i];
	}
	pot=wqs::slove();
	ans=0;
	for(int i=0,lst=0;i<k;i++) ans+=calc(lst+1,pot[i]),lst=pot[i];
	res=pot;
	// printf("pot0: %d\n",pot[0]);
	// for(int x:pot) printf("%d ",x);
	// printf("\n");
	// if(pot[0]>n/k) printf("ERR");
	// if(n==200000&&L==1000000000000ll) return 0;
	slove::slove(seq);
	printf("%lld\n",ans);
	vector<ll> tmp;
	for(int i=0;i<k;i++)
	{
		int lb=res[(i-1+k)%k]+1,rb=res[i];
		if(lb>rb) rb+=n;
		tmp.push_back(a[lb+rb>>1]%L);
	}
	sort(tmp.begin(),tmp.end());
	for(ll x:tmp) printf("%lld ",x);
	printf("\n");
	return 0;
}