P5298 [PKUWC2018]Minimax 做题记录

P5298 [PKUWC2018]Minimax

设 $D_{u,i}$ 为节点 $u$ 的权值恰为第 $i$ 小叶子权值的概率，$s_1,s_2$ 为 $u$ 的两个儿子，$P_u$ 为 $u$ 取儿子权值最大值的概率。首先有暴力转移：

$$D_{u,i}=D_{s_1,i}\left(P_u\sum_{j<i}D_{s_2,j}+(1-P_u)\sum_{j>i}D_{s_2,j}\right)+D_{s_2,i}\left(P_u\sum_{j<i}D_{s_1,j}+(1-P_u)\sum_{j>i}D_{s_1,j}\right)$$

然后可以用线段树合并来转移。节点 $u$ 的线段树中，每个节点维护其区间内 $D_{u,i}$ 的和。合并时，对每棵树分别维护另一棵树在当前区间左边的和与右边的和；当某一棵树在当前区间为空时，给另一棵树的整棵子树打上乘法 lazytag 即可。

// Problem: P5298 [PKUWC2018]Minimax
// Contest: Luogu
// URL: https://www.luogu.com.cn/problem/P5298
// Memory Limit: 500 MB
// Time Limit: 1000 ms
// 
// Powered by CP Editor (https://cpeditor.org)

#include <iostream>
#include <algorithm>
#include <cstdio>

using namespace std;

// Capacity of the per-tree-node arrays below (indexed 1..n).
constexpr int MS = 300005;
// Capacity of the segment-tree node pool `tr`.
constexpr int S = 5000005;
// Modulus, and the modular inverse of 10000 modulo p
// (main scales internal-node inputs by `inv`).
constexpr int p = 998244353;
constexpr int inv = 796898467;

// One node of the dynamically allocated, mergeable segment trees.
// Index 0 acts as the empty/absent node.
struct node
{
	int sm; // sum of the stored values over this node's range, mod p
	int mu; // pending multiplicative lazy tag
	int ls; // left child index into tr (0 = none)
	int rs; // right child index into tr (0 = none)
} tr[S];

int n;      // number of tree nodes
int a[MS];  // leaf: weight, later replaced by its rank in b; internal: P_u mod p
int m;      // number of leaves
int b[MS];  // leaf weights, sorted ascending in main
int ls[MS]; // first child of each node (0 = none)
int rs[MS]; // second child of each node (0 = none)
int cnt;    // number of pool nodes allocated so far
int rt[MS]; // root (index into tr) of each node's segment tree

inline void adtg(int u,int val)
{
	tr[u].sm=1ll*tr[u].sm*val%p,tr[u].mu=1ll*tr[u].mu*val%p;
}

// Hand u's pending multiplier down to both children, then clear u's tag to 1.
inline void dntg(int u)
{
	const int tag=tr[u].mu;
	adtg(tr[u].ls,tag);
	adtg(tr[u].rs,tag);
	tr[u].mu=1;
}

// Rebuild u's sum from its two children (an absent child is tr[0]).
inline void upda(int u)
{
	const int lsum=tr[tr[u].ls].sm;
	const int rsum=tr[tr[u].rs].sm;
	tr[u].sm=(lsum+rsum)%p;
}

// Set position pos of the tree rooted at u (covering [l, r]) to val.
// Missing nodes are taken from the pool with sum 0 and tag 1; u is updated
// in place when a new root is created.
void upd(int &u,int l,int r,int pos,int val)
{
	if(!u)
	{
		u=++cnt;
		tr[u]=node{0,1,0,0};
	}
	if(l!=r)
	{
		dntg(u);
		const int mid=(l+r)>>1;
		if(pos>mid) upd(tr[u].rs,mid+1,r,pos,val);
		else upd(tr[u].ls,l,mid,pos,val);
		upda(u);
		return;
	}
	tr[u].sm=val;
}

// Return the value stored at position pos in the tree rooted at u (covering
// [l, r]), pushing lazy tags down along the root-to-leaf path.
int que(int u,int l,int r,int pos)
{
	while(l!=r)
	{
		dntg(u);
		const int mid=(l+r)>>1;
		if(pos<=mid)
		{
			u=tr[u].ls;
			r=mid;
		}
		else
		{
			u=tr[u].rs;
			l=mid+1;
		}
	}
	return tr[u].sm;
}

// Merge the segment trees rooted at x and y while applying the two-child
// transition of internal node u. Returns the merged root: x when both trees are
// non-empty, otherwise whichever one is non-empty (0 if both are empty).
// lftx / rigx: sum of y's values strictly left / right of the current range;
//   these scale x's entries (see the y==0 branch).
// lfty / rigy: the same sums taken over x; these scale y's entries.
// a[u] is used as P_u (mod p). dfs starts the merge with all four sums 0.
// NOTE(review): meg carries no [l, r], so it assumes x and y never both reach
// the same leaf position (leaf weights pairwise distinct) — confirm against the
// problem constraints.
int meg(int x,int y,int lftx,int rigx,int lfty,int rigy,int u)
{
	if(x==0&&y==0) return 0;
	// Only x has values here: every entry in range sees the same y-sums, so the
	// whole subtree is scaled by P_u*lftx + (1-P_u)*rigx through the lazy tag.
	if(y==0)
	{
		int val=(1ll*a[u]*lftx%p+1ll*(1-a[u]+p)*rigx%p)%p;
		return adtg(x,val),x;
	}
	// Symmetric case: only y has values here.
	if(x==0)
	{
		int val=(1ll*a[u]*lfty%p+1ll*(1-a[u]+p)*rigy%p)%p;
		return adtg(y,val),y;
	}
	// Both present: flush pending tags so the child sums read below are current.
	dntg(x),dntg(y);
	// Snapshot the child sums before the recursive calls overwrite them.
	int xlsm=tr[tr[x].ls].sm,xrsm=tr[tr[x].rs].sm;
	int ylsm=tr[tr[y].ls].sm,yrsm=tr[tr[y].rs].sm;
	// Left half: the other tree's right child lies entirely to its right.
	tr[x].ls=meg(tr[x].ls,tr[y].ls,lftx,(rigx+yrsm)%p,lfty,(rigy+xrsm)%p,u);
	// Right half: the other tree's left child lies entirely to its left.
	tr[x].rs=meg(tr[x].rs,tr[y].rs,(lftx+ylsm)%p,rigx,(lfty+xlsm)%p,rigy,u);
	upda(x);
	return x;
}

void dfs(int u)
{
	if(ls[u]!=0) dfs(ls[u]);
	if(rs[u]!=0) dfs(rs[u]);
	if(ls[u]==0&&rs[u]==0) upd(rt[u],1,m,a[u],1);
	else if(ls[u]==0||rs[u]==0) rt[u]=rt[ls[u]]+rt[rs[u]];
	else rt[u]=meg(rt[ls[u]],rt[rs[u]],0,0,0,0,u);
}

// Read the tree and node values. Internal-node values are converted to
// P_u = value * inv (mod p), and leaf weights are compressed to ranks 1..m.
// After running the DP, print sum_i i * b[i] * D_i^2 (mod p), where D_i is
// read from the root's tree.
int main()
{
	scanf("%d",&n);
	for(int i=1;i<=n;i++)
	{
		int fa;
		scanf("%d",&fa);
		if(fa==0) continue;
		// The first child seen goes into ls, the second into rs.
		int &slot=(ls[fa]==0)?ls[fa]:rs[fa];
		slot=i;
	}
	for(int i=1;i<=n;i++)
	{
		scanf("%d",&a[i]);
		const bool leaf=(ls[i]==0&&rs[i]==0);
		if(leaf) b[++m]=a[i];
		else a[i]=(int)(1ll*a[i]*inv%p);
	}
	sort(b+1,b+m+1);
	for(int i=1;i<=n;i++)
	{
		if(ls[i]!=0||rs[i]!=0) continue;
		a[i]=(int)(lower_bound(b+1,b+m+1,a[i])-b);
	}
	dfs(1);
	long long ans=0;
	for(int i=1;i<=m;i++)
	{
		const long long Di=que(rt[1],1,m,i);
		ans=(ans+1ll*i*b[i]%p*Di%p*Di)%p;
	}
	printf("%d\n",(int)ans);
	return 0;
}