首先有暴力转移:
然后可以用线段树合并来转移。节点 的线段树的节点维护区间内 的和,合并的时候维护当前区间左边的和与当前区间右边的和,打上 lazytag 即可。
// Problem: P5298 [PKUWC2018]Minimax
// Contest: Luogu
// URL: https://www.luogu.com.cn/problem/P5298
// Memory Limit: 500 MB
// Time Limit: 1000 ms
//
// Powered by CP Editor (https://cpeditor.org)
#include <iostream>
#include <algorithm>
#include <cstdio>
using namespace std;
const int MS=300005,S=5000005;
const int p=998244353,inv=796898467;
struct node
{
int sm,mu;
int ls,rs;
}tr[S];
int n,a[MS];
int m,b[MS];
int ls[MS],rs[MS];
int cnt,rt[MS];
inline void adtg(int u,int val)
{
tr[u].sm=1ll*tr[u].sm*val%p,tr[u].mu=1ll*tr[u].mu*val%p;
}
inline void dntg(int u)
{
adtg(tr[u].ls,tr[u].mu),adtg(tr[u].rs,tr[u].mu),tr[u].mu=1;
}
inline void upda(int u)
{
tr[u].sm=(tr[tr[u].ls].sm+tr[tr[u].rs].sm)%p;
}
void upd(int &u,int l,int r,int pos,int val)
{
if(u==0) tr[u=++cnt]=(node){0,1,0,0};
if(l==r)
{
tr[u].sm=val;
return;
}
dntg(u);
int mid=l+r>>1;
if(pos<=mid) upd(tr[u].ls,l,mid,pos,val);
else upd(tr[u].rs,mid+1,r,pos,val);
upda(u);
}
int que(int u,int l,int r,int pos)
{
if(l==r) return tr[u].sm;
dntg(u);
int mid=l+r>>1;
if(pos<=mid) return que(tr[u].ls,l,mid,pos);
else return que(tr[u].rs,mid+1,r,pos);
}
int meg(int x,int y,int lftx,int rigx,int lfty,int rigy,int u)
{
if(x==0&&y==0) return 0;
if(y==0)
{
int val=(1ll*a[u]*lftx%p+1ll*(1-a[u]+p)*rigx%p)%p;
return adtg(x,val),x;
}
if(x==0)
{
int val=(1ll*a[u]*lfty%p+1ll*(1-a[u]+p)*rigy%p)%p;
return adtg(y,val),y;
}
dntg(x),dntg(y);
int xlsm=tr[tr[x].ls].sm,xrsm=tr[tr[x].rs].sm;
int ylsm=tr[tr[y].ls].sm,yrsm=tr[tr[y].rs].sm;
tr[x].ls=meg(tr[x].ls,tr[y].ls,lftx,(rigx+yrsm)%p,lfty,(rigy+xrsm)%p,u);
tr[x].rs=meg(tr[x].rs,tr[y].rs,(lftx+ylsm)%p,rigx,(lfty+xlsm)%p,rigy,u);
upda(x);
return x;
}
void dfs(int u)
{
if(ls[u]!=0) dfs(ls[u]);
if(rs[u]!=0) dfs(rs[u]);
if(ls[u]==0&&rs[u]==0) upd(rt[u],1,m,a[u],1);
else if(ls[u]==0||rs[u]==0) rt[u]=rt[ls[u]]+rt[rs[u]];
else rt[u]=meg(rt[ls[u]],rt[rs[u]],0,0,0,0,u);
}
int main()
{
scanf("%d",&n);
for(int i=1;i<=n;i++)
{
int fa;
scanf("%d",&fa);
if(fa!=0)
{
if(ls[fa]==0) ls[fa]=i;
else rs[fa]=i;
}
}
for(int i=1;i<=n;i++)
{
scanf("%d",&a[i]);
if(ls[i]!=0||rs[i]!=0) a[i]=1ll*a[i]*inv%p;
else b[++m]=a[i];
}
sort(b+1,b+m+1);
for(int i=1;i<=n;i++) if(ls[i]==0&&rs[i]==0) a[i]=lower_bound(b+1,b+m+1,a[i])-b;
dfs(1);
int ans=0;
for(int i=1;i<=m;i++)
{
int Di=que(rt[1],1,m,i);
ans=(ans+1ll*i*b[i]%p*Di%p*Di%p)%p;
}
printf("%d\n",ans);
return 0;
}