动态 DP 学习笔记

动态 dp,其实是一个用线段树+矩阵乘法来实现带修快速对一段区间进行 dp 的算法。

Part 1 序列上动态 DP

主体思路

首先学过矩阵快速幂优化递推的同学肯定都知道,矩阵乘法本质上可以是 dp 的转移

那么我们就可以在构造出状态 i1i-1\to 状态 ii 的转移矩阵后,使用线段树维护这些矩阵,实现快速进行 dp

但是有个问题,有些 dp 需要最大值和最小值操作。这时,我们就可以重新定义一下矩阵乘法了:

Ci,j=maxkAi,k+Bk,jC_{i,j}=\max\limits_{k} A_{i,k}+B_{k,j}

或者

Ci,j=minkAi,k+Bk,jC_{i,j}=\min\limits_{k} A_{i,k}+B_{k,j}

稍加思考可以发现,这样定义矩阵乘法,可以方便大多数 dp 的转移:

  • 若某一项需要加上 xx,那么对应位置设为 xx

  • 若某一项不能算进答案,那么对应位置设为 inf/inf-inf/inf

而且它满足结合律

但是如果某一项需要乘上系数,dp 转移还需要最大/最小值,那么动态 dp 就无法解决了

由于线段树的操作是 O(logn)O(\log n) 的,所以动态 dp 的单次时间复杂度是 O(logn)O(\log n)

例题讲解

SP1043 GSS1 - Can you answer these queries I

可以发现,这道题和小白逛公园十分类似,但我们尝试使用动态 dp 求解。

考虑朴素的 dp,定义 fif_i 为以 ii 结尾的最大子段和,gi=maxjifjg_i=\max\limits_{j\le i} f_j

那么易得转移方程:

{fi=max(fi1+ai,ai)gi=max(gi1,fi)\begin{cases}f_i=\max(f_{i-1}+a_i,a_i)\\g_i=\max(g_{i-1},f_i)\end{cases}

接下来考虑转移矩阵的构造。

设当前已经求出了 fif_igig_i,要求 fi+1f_{i+1}gi+1g_{i+1}

那么设当前矩阵为 [fi0gi]\begin{bmatrix}f_i&0&g_i\end{bmatrix},则可以构造转移矩阵:

[fi0gi][aiinfaiai0aiinfinf0]=[fi+10gi+1]\begin{bmatrix}f_i&0&g_i\end{bmatrix}\cdot\begin{bmatrix}a_i&-inf&a_i\\a_i&0&a_i\\-inf&-inf&0\end{bmatrix}=\begin{bmatrix}f_{i+1}&0&g_{i+1}\end{bmatrix}

注意此时的矩阵乘法是重新定义过的:Ci,j=maxkAi,k+Bk,jC_{i,j}=\max\limits_{k} A_{i,k}+B_{k,j}

推出转移矩阵后,我们只要造一棵线段树来维护就好了。

代码如下:

#include <iostream>
#include <cstdio>
#include <cstring>

using namespace std;

const long long MS=50005;

struct mrt
{
	int n,m;
	long long dat[15][15];
}tree[MS<<4];

int n,m;
long long a[MS];

inline long long read()
{
	long long s=0,w=1,ch=getchar();
	while(ch<'0'||ch>'9') ch=='-'?w=-1,ch=getchar():ch=getchar();
	while(ch>='0'&&ch<='9') s=(s<<1)+(s<<3)+(ch^48),ch=getchar();
	return s*w;
}

inline mrt mrtmul(mrt x,mrt y)
{
	if(x.n==-1&&x.m==-1)
	{
		return y;
	}
	if(y.n==-1&&y.m==-1)
	{
		return x;
	}
	mrt res;
	res.n=x.n;
	res.m=y.m;
	for(int i=1;i<=res.n;i++)
	{
		for(int j=1;j<=res.m;j++)
		{
			res.dat[i][j]=-1e17;
		}
	}
	for(int i=1;i<=x.n;i++)
	{
		for(int j=1;j<=y.m;j++)
		{
			for(int k=1;k<=x.m;k++)
			{
				res.dat[i][j]=max(res.dat[i][j],x.dat[i][k]+y.dat[k][j]);
			}
		}
	}
	return res;
}

inline void updata(int u)
{
	tree[u]=mrtmul(tree[u<<1],tree[u<<1|1]);
}

void build(int u,int l,int r)
{
	if(l==r)
	{
		tree[u].n=3;
		tree[u].m=3;
		tree[u].dat[1][1]=a[l];
		tree[u].dat[1][2]=-1e17;
		tree[u].dat[1][3]=a[l];
		tree[u].dat[2][1]=a[l];
		tree[u].dat[2][2]=0;
		tree[u].dat[2][3]=a[l];
		tree[u].dat[3][1]=-1e17;
		tree[u].dat[3][2]=-1e17;
		tree[u].dat[3][3]=0;
		return;
	}
	int mid=l+r>>1;
	build(u<<1,l,mid);
	build(u<<1|1,mid+1,r);
	updata(u);
}

void que(int u,int l,int r,int L,int R,mrt& res)
{
	if(r<L||l>R)
	{
		return;
	}
	if(l>=L&&r<=R)
	{
		res=mrtmul(res,tree[u]);
		return;
	}
	int mid=l+r>>1;
	if(L<=mid)
	{
		que(u<<1,l,mid,L,R,res);
	}
	if(R>=mid+1)
	{
		que(u<<1|1,mid+1,r,L,R,res);
	}
}

int main()
{
	n=read();
	for(int i=1;i<=n;i++)
	{
		a[i]=read();
	}
	build(1,1,n);
	m=read();
	while(m--)
	{
		int x,y;
		x=read();
		y=read();
		if(x>y)
		{
			swap(x,y);
		}
		if(x==y)
		{
			printf("%lld\n",a[x]);
			continue;
		}
		mrt tmp;
		tmp.n=-1;
		tmp.m=-1;
		que(1,1,n,x+1,y,tmp);
		mrt ans;
		ans.n=1;
		ans.m=3;
		ans.dat[1][1]=a[x];
		ans.dat[1][2]=0;
		ans.dat[1][3]=a[x];
		ans=mrtmul(ans,tmp);
		printf("%lld\n",ans.dat[1][3]);
	}
	return 0;
}

SP1716 GSS3 - Can you answer these queries III

可以发现这题是上题的加强版,动态 dp 的强大之处也在这题体现了。

这题多了一个单点修改的操作,那么只要对于修改的那个位置单独重新构造矩阵即可。

代码如下:

#include <iostream>
#include <cstdio>
#include <cstring>

using namespace std;

const long long MS=50005;

struct mrt
{
	int n,m;
	long long dat[15][15];
}tree[MS<<4];

int n,m;
long long a[MS];

inline long long read()
{
	long long s=0,w=1,ch=getchar();
	while(ch<'0'||ch>'9') ch=='-'?w=-1,ch=getchar():ch=getchar();
	while(ch>='0'&&ch<='9') s=(s<<1)+(s<<3)+(ch^48),ch=getchar();
	return s*w;
}

inline mrt mrtmul(mrt x,mrt y)
{
	if(x.n==-1&&x.m==-1)
	{
		return y;
	}
	if(y.n==-1&&y.m==-1)
	{
		return x;
	}
	mrt res;
	res.n=x.n;
	res.m=y.m;
	for(int i=1;i<=res.n;i++)
	{
		for(int j=1;j<=res.m;j++)
		{
			res.dat[i][j]=-1e17;
		}
	}
	for(int i=1;i<=x.n;i++)
	{
		for(int j=1;j<=y.m;j++)
		{
			for(int k=1;k<=x.m;k++)
			{
				res.dat[i][j]=max(res.dat[i][j],x.dat[i][k]+y.dat[k][j]);
			}
		}
	}
	return res;
}

inline void updata(int u)
{
	tree[u]=mrtmul(tree[u<<1],tree[u<<1|1]);
}

void build(int u,int l,int r)
{
	if(l==r)
	{
		tree[u].n=3;
		tree[u].m=3;
		tree[u].dat[1][1]=a[l];
		tree[u].dat[1][2]=-1e17;
		tree[u].dat[1][3]=a[l];
		tree[u].dat[2][1]=a[l];
		tree[u].dat[2][2]=0;
		tree[u].dat[2][3]=a[l];
		tree[u].dat[3][1]=-1e17;
		tree[u].dat[3][2]=-1e17;
		tree[u].dat[3][3]=0;
		return;
	}
	int mid=l+r>>1;
	build(u<<1,l,mid);
	build(u<<1|1,mid+1,r);
	updata(u);
}

void upd(int u,int l,int r,int pos,long long val)
{
	if(l==r)
	{
		tree[u].n=3;
		tree[u].m=3;
		tree[u].dat[1][1]=val;
		tree[u].dat[1][2]=-1e17;
		tree[u].dat[1][3]=val;
		tree[u].dat[2][1]=val;
		tree[u].dat[2][2]=0;
		tree[u].dat[2][3]=val;
		tree[u].dat[3][1]=-1e17;
		tree[u].dat[3][2]=-1e17;
		tree[u].dat[3][3]=0;
		return;
	}
	int mid=l+r>>1;
	if(pos<=mid)
	{
		upd(u<<1,l,mid,pos,val);
	}
	else
	{
		upd(u<<1|1,mid+1,r,pos,val);
	}
	updata(u);
}

void que(int u,int l,int r,int L,int R,mrt& res)
{
	if(r<L||l>R)
	{
		return;
	}
	if(l>=L&&r<=R)
	{
		res=mrtmul(res,tree[u]);
		return;
	}
	int mid=l+r>>1;
	if(L<=mid)
	{
		que(u<<1,l,mid,L,R,res);
	}
	if(R>=mid+1)
	{
		que(u<<1|1,mid+1,r,L,R,res);
	}
}

int main()
{
	n=read();
	for(int i=1;i<=n;i++)
	{
		a[i]=read();
	}
	build(1,1,n);
	m=read();
	while(m--)
	{
		int ty;
		ty=read();
		if(ty==0)
		{
			int x=read();
			long long y=read();
			a[x]=y;
			upd(1,1,n,x,y);
		}
		else
		{
			int x=read();
			int y=read();
			if(x>y)
			{
				swap(x,y);
			}
			if(x==y)
			{
				printf("%lld\n",a[x]);
				continue;
			}
			mrt tmp;
			tmp.n=-1;
			tmp.m=-1;
			que(1,1,n,x+1,y,tmp);
			mrt ans;
			ans.n=1;
			ans.m=3;
			ans.dat[1][1]=a[x];
			ans.dat[1][2]=0;
			ans.dat[1][3]=a[x];
			ans=mrtmul(ans,tmp);
			printf("%lld\n",ans.dat[1][3]);
		}
	}
	return 0;
}

CF750E New Year and Old Subsequence

我们可以设 fi,0/1/2/3/4f_{i,0/1/2/3/4} 表示前 ii 个字符,包含 /2/20/201/2017\varnothing/2/20/201/2017 至少需要删除多少个字符,那么转移方程:

{fi,0=fi1,0+[si=2]fi,1=min(fi1,1+[si=0],fi1,0[si=2])fi,2=min(fi1,2+[si=1],fi1,1[si=0])fi,3=min(fi1,3+[si=7si=6],fi1,2[si=1])fi,4=min(fi1,4+[si=6],fi1,3[si=7])\begin{cases}f_{i,0}=f_{i-1,0}+[s_i=2]\\f_{i,1}=\min(f_{i-1,1}+[s_i=0],f_{i-1,0}[s_i=2])\\f_{i,2}=\min(f_{i-1,2}+[s_i=1],f_{i-1,1}[s_i=0])\\f_{i,3}=\min(f_{i-1,3}+[s_i=7\vee s_i=6],f_{i-1,2}[s_i=1])\\f_{i,4}=\min(f_{i-1,4}+[s_i=6],f_{i-1,3}[s_i=7])\end{cases}

构造转移矩阵即可。

代码如下:

#include <iostream>
#include <cstdio>

using namespace std;

const long long MS=200000;
const int inf=1e8;

struct mrt
{
	int n,m;
	int dat[6][6];
}tree[(MS<<2)+5];

int n,m;
char a[MS+5];

inline void mrtmul(mrt &res,mrt x,mrt y)
{
	if(x.n==-1&&x.m==-1)
	{
		res=y;
		return;
	}
	if(y.n==-1&&y.m==-1)
	{
		res=x;
		return;
	}
	res.n=x.n;
	res.m=y.m;
	for(int i=1;i<=res.n;i++)
	{
		for(int j=1;j<=res.m;j++)
		{
			res.dat[i][j]=inf;
		}
	}
	for(int i=1;i<=res.n;i++)
	{
		for(int j=1;j<=res.m;j++)
		{
			for(int k=1;k<=x.m;k++)
			{
				res.dat[i][j]=min(res.dat[i][j],x.dat[i][k]+y.dat[k][j]);
			}
		}
	}
}

inline void updata(int u)
{
	mrtmul(tree[u],tree[u<<1],tree[u<<1|1]);
}

void build(int u,int l,int r)
{
	if(l==r)
	{
		tree[u].n=5;
		tree[u].m=5;
		for(int i=1;i<=5;i++)
		{
			for(int j=1;j<=5;j++)
			{
				tree[u].dat[i][j]=inf;
			}
		}
		tree[u].dat[1][1]=a[l]==2;
		tree[u].dat[1][2]=(a[l]!=2)*inf;
		tree[u].dat[2][2]=a[l]==0;
		tree[u].dat[2][3]=(a[l]!=0)*inf;
		tree[u].dat[3][3]=a[l]==1;
		tree[u].dat[3][4]=(a[l]!=1)*inf;
		tree[u].dat[4][4]=a[l]==6||a[l]==7;
		tree[u].dat[4][5]=(a[l]!=7)*inf;
		tree[u].dat[5][5]=a[l]==6;
		return;
	}
	int mid=l+r>>1;
	build(u<<1,l,mid);
	build(u<<1|1,mid+1,r);
	updata(u);
}

void que(int u,int l,int r,int L,int R,mrt &res)
{
	if(r<L||l>R)
	{
		return;
	}
	if(l>=L&&r<=R)
	{
		mrtmul(res,res,tree[u]);
		return;
	}
	int mid=l+r>>1;
	if(L<=mid)
	{
		que(u<<1,l,mid,L,R,res);
	}
	if(R>=mid+1)
	{
		que(u<<1|1,mid+1,r,L,R,res);
	}
}

int main()
{
	scanf("%d%d",&n,&m);
	scanf("%s",a+1);
	for(int i=1;i<=n;i++)
	{
		a[i]-='0';
	}
	build(1,1,n);
	while(m--)
	{
		int l,r;
		scanf("%d%d",&l,&r);
		if(l>r)
		{
			swap(l,r);
		}
		if(r-l+1<4)
		{
			puts("-1");
			continue;
		}
		mrt tmp;
		tmp.n=-1;
		tmp.m=-1;
		que(1,1,n,l+1,r,tmp);
		mrt ans;
		ans.n=1;
		ans.m=5;
		ans.dat[1][1]=a[l]==2;
		ans.dat[1][2]=(a[l]!=2)*inf;
		ans.dat[1][3]=inf;
		ans.dat[1][4]=inf;
		ans.dat[1][5]=inf;
		mrtmul(ans,ans,tmp);
		printf("%d\n",ans.dat[1][5]>n?-1:ans.dat[1][5]);
	}
	return 0;
}

练习题目

Part 2 树上动态 DP

看一道例题:P4719 【模板】"动态 DP"&动态树分治

给你一棵树,点有点权,带修,每次修改完输出最大独立集的大小。

我们设 fu,0/1f_{u,0/1} 表示点 uu 没选/选了时 uu 的子树的最大收益,那么显然有:

{fu,0=jmax(fj,0,fj,1)fu,1=jfj,0\begin{cases}f_{u,0}=\sum\limits_{j}\max(f_{j,0},f_{j,1})\\f_{u,1}=\sum\limits_{j}f_{j,0}\end{cases}

其中 jjii 的一个儿子。

但是当你尝试构造转移矩阵时,你就会发现这是一对多的转移,不好维护,所以我们可以跑一下树剖,并且引入一个 gu,0/1g_{u,0/1} 表示点 uu 的所有轻儿子可选可不选/都不选时的最大收益,那么有转移:

{gu,0=jmax(fj,0,fj,1)gu,1=au+jfj,0\begin{cases}g_{u,0}=\sum\limits_{j}\max(f_{j,0},f_{j,1})\\g_{u,1}=a_u+\sum\limits_{j}f_{j,0}\end{cases}

其中 jjuu 的一个轻儿子。

{fu,0=max(fv,0,fv,1)+gu,0fu,1=fv,0+gu,1\begin{cases}f_{u,0}=\max(f_{v,0},f_{v,1})+g_{u,0}\\f_{u,1}=f_{v,0}+g_{u,1}\end{cases}

其中 vvuu 的重儿子。

这样可以构造转移矩阵了:(其中的矩阵乘法是重新定义的Ci,j=maxkAi,k+Bk,jC_{i,j}=\max\limits_{k} A_{i,k}+B_{k,j}

[gu,0gu,0gu,1inf][fv,0fv,1]=[fu,0fu,1]\begin{bmatrix}g_{u,0}&g_{u,0}\\g_{u,1}&-\inf\end{bmatrix}\cdot\begin{bmatrix}f_{v,0}\\f_{v,1}\end{bmatrix}=\begin{bmatrix}f_{u,0}\\f_{u,1}\end{bmatrix}

这么构造是因为跑完树剖之后求一段区间矩阵的乘积是从深度小的往深度大的乘过去

考虑怎么修改,假设当前要修改 uu,那么我们可以求出 uu 所在的重链在修改 uu 前的答案 frt0/1frt_{0/1},然后修改 uu,接下来求出修改 uu 后的答案 lst0/1lst_{0/1}

然后,我们考虑修改 uu 对它所属重链顶端的节点的父亲的贡献,设 uu 所属重链顶端的节点的父亲为 pppp 的父节点为 qq那么显然 ppqq 的轻儿子,那么我们令

gq,0=gq,0max(frt0,frt1)+max(lst0,lst1)gq,1=gq,1frt0+lst0g_{q,0}=g_{q,0}-\max(frt_0,frt_1)+\max(lst_0,lst_1)\text{,}g_{q,1}=g_{q,1}-frt_0+lst_0

然后又会发现,这个是关于 qq 和它的重链的修改。所以我们不断地跳、修改即可

代码如下:

#include <iostream>
#include <cstdio>

using namespace std;

const long long S=200005,MS=100005;
const int inf=50000005;

struct mrt
{
	int n,m;
	int dat[4][4];
}tree[MS<<2];

int n,m;
int val[MS];
int esum,to[S],nxt[S],h[MS];
int f[MS][2],g[MS][2];
int fat[MS],siz[MS],hson[MS];
int cnt,id[MS],vl[MS],a[MS][2],top[MS],depp[MS];
mrt b[MS];

inline void add(int x,int y)
{
	to[++esum]=y;
	nxt[esum]=h[x];
	h[x]=esum;
}

void dfs1(int u,int fa)
{
	f[u][1]=val[u];
	fat[u]=fa;
	siz[u]=1;
	for(int i=h[u];i;i=nxt[i])
	{
		int v=to[i];
		if(v==fa)
		{
			continue;
		}
		dfs1(v,u);
		f[u][0]+=max(f[v][0],f[v][1]);
		f[u][1]+=f[v][0];
		siz[u]+=siz[v];
		if(siz[v]>siz[hson[u]])
		{
			hson[u]=v;
		}
	}
}

void dfs2(int u,int fa,int tpf)
{
	g[u][1]=val[u];
	for(int i=h[u];i;i=nxt[i])
	{
		int v=to[i];
		if(v==fa||v==hson[u])
		{
			continue;
		}
		g[u][0]+=max(f[v][0],f[v][1]);
		g[u][1]+=f[v][0];
	}
	id[u]=++cnt;
	vl[cnt]=val[u];
	a[cnt][0]=g[u][0];
	a[cnt][1]=g[u][1];
	top[u]=tpf;
	depp[tpf]=cnt;
	if(hson[u]!=0)
	{
		dfs2(hson[u],u,tpf);
	}
	for(int i=h[u];i;i=nxt[i])
	{
		int v=to[i];
		if(v==fa||v==hson[u])
		{
			continue;
		}
		dfs2(v,u,v);
	}
}

inline void mrtmul(mrt& res,mrt x,mrt y)
{
	if(x.n==-1&&x.m==-1)
	{
		res=y;
		return;
	}
	if(y.n==-1&&y.m==-1)
	{
		res=x;
		return;
	}
	res.n=x.n;
	res.m=y.m;
	for(int i=1;i<=res.n;i++)
	{
		for(int j=1;j<=res.m;j++)
		{
			res.dat[i][j]=-inf;
		}
	}
	for(int i=1;i<=res.n;i++)
	{
		for(int j=1;j<=res.m;j++)
		{
			for(int k=1;k<=x.m;k++)
			{
				res.dat[i][j]=max(res.dat[i][j],x.dat[i][k]+y.dat[k][j]); 
			}
		}
	}
}

inline void updata(int u)
{
	mrtmul(tree[u],tree[u<<1],tree[u<<1|1]);
}

void build(int u,int l,int r)
{
	if(l==r)
	{
		tree[u].n=2;
		tree[u].m=2;
		tree[u].dat[1][1]=a[l][0];
		tree[u].dat[1][2]=a[l][0];
		tree[u].dat[2][1]=a[l][1];
		tree[u].dat[2][2]=-inf;
		b[l]=tree[u];
		return;
	}
	int mid=l+r>>1;
	build(u<<1,l,mid);
	build(u<<1|1,mid+1,r);
	updata(u);
}

void que(int u,int l,int r,int L,int R,mrt& res)
{
	if(r<L||l>R)
	{
		return;
	}
	if(l>=L&&r<=R)
	{
		mrtmul(res,res,tree[u]);
		return;
	}
	int mid=l+r>>1;
	if(L<=mid)
	{
		que(u<<1,l,mid,L,R,res);
	}
	if(R>=mid+1)
	{
		que(u<<1|1,mid+1,r,L,R,res);
	}
}

void upd(int u,int l,int r,int pos)
{
	if(l==r)
	{
		tree[u]=b[l];
		return;
	}
	int mid=l+r>>1;
	if(pos<=mid)
	{
		upd(u<<1,l,mid,pos);
	}
	else
	{
		upd(u<<1|1,mid+1,r,pos);
	}
	updata(u);
}

inline void updnode(int u,int k)
{
	b[id[u]].dat[2][1]+=k-vl[id[u]];
	vl[id[u]]=k;
	while(u!=0)
	{
		mrt frt,lst;
		frt.n=-1;
		frt.m=-1;
		lst.n=-1;
		lst.m=-1;
		que(1,1,n,id[top[u]],depp[top[u]],frt);
		upd(1,1,n,id[u]);
		que(1,1,n,id[top[u]],depp[top[u]],lst);
		u=fat[top[u]];
		b[id[u]].dat[1][1]+=max(lst.dat[1][1],lst.dat[2][1])-max(frt.dat[1][1],frt.dat[2][1]);
		b[id[u]].dat[1][2]=b[id[u]].dat[1][1];
		b[id[u]].dat[2][1]+=lst.dat[1][1]-frt.dat[1][1];
	}
}


int main()
{
	scanf("%d%d",&n,&m);
	for(int i=1;i<=n;i++)
	{
		scanf("%d",&val[i]);
	}
	for(int i=1;i<=n-1;i++)
	{
		int x,y;
		scanf("%d%d",&x,&y);
		add(x,y);
		add(y,x);
	}
	dfs1(1,0);
	dfs2(1,0,1);
	build(1,1,n);
	while(m--)
	{
		int x,y;
		scanf("%d%d",&x,&y);
		updnode(x,y);
		mrt ans;
		ans.n=-1;
		ans.m=-1;
		que(1,1,n,id[top[1]],depp[top[1]],ans);
		printf("%d\n",max(ans.dat[1][1],ans.dat[2][1]));
	}
	return 0;
}

练习题目