【DP】区间DP入门

在开始以前我要感谢y总,是他精彩的讲解才让我对区间DP有较深的认识。c++

简介

通常是线性结构上的对区间进行求解最值,计数的动态规划。大体思路是枚举断点,而后对断点两边求取最优解,而后进行合并从而得解。spa

原理

结合模板题(合并石子)讲述:https://www.acwing.com/problem/content/284/code

由于题目具备合并相邻物品的性质,因此在合并的过程当中,必然会在最后一步出现两个物品合二为一的状况,而这两个物品则是分别由左侧的物品、右侧的物品合并而来的。 所以,咱们的思路是枚举最后一步合并两个物品时候的断点(记为 \(k\) ),为了方便起见,咱们能够将断点放在某个物品上面。ci

结合样例具体来讲:get

k
1 3 5 2
k
1 3 5 2
k
1 3 5 2
k
1 3 5 2

上面即是四个断点。it


对于本题,咱们记f[l][r]为合并 \([l,r]\) 的物品所能获得的最小贡献。
而断点将 \([l,r]\) 分为了 \([l,k],[k+1,r]\) ,这两个区间的贡献分别是 f[l][k],f[k+1][r] 而合并这两个区间的贡献则是 sum(l,r)) (其中sum(l,r) 表示 \([l,r]\) 的物品的权值和)模板

从而获得递推方程式: f[l][r] = min(f[l][r],f[l][k]+f[k+1][r]+sum(l,r))class

能够看出,在枚举断点的过程当中,咱们已经覆盖了全部状况(根据断点全部可能位置分类),所以这样作可以保证获得答案。原理

至此,在思惟上不会有太大困难。二叉树

下面讲一下怎么用递推的方法求解:

由本题的逻辑结构可知,咱们要先处理出小区间的 \(f值\) 才可以保证大区间能够获得更新,因此咱们第一重循环枚举的是区间的长度len,下面的部分则是枚举起点(即 l), 结合长度咱们能够获得 r = l+len-1 ,进而咱们获得了相应的区间 \([l,r]\) ,接下来枚举断点 \(k\) 便可。

结合代码理解:

#include<bits/stdc++.h>
using namespace std;

const int INF=0x3f3f3f3f;
const int N=305;

int f[N][N];
int w[N],s[N];
int n;
int main(){
    cin>>n;
    for(int i=1;i<=n;i++){
        cin>>w[i];
        s[i]=s[i-1]+w[i];
    }
    
    for(int len=1;len<=n;len++)
        for(int l=1;l+len-1<=n;l++){
            int r=l+len-1;
            if(len==1){
                f[l][r]=0;
            }else{
                f[l][r]=INF;
                for(int k=l;k<r;k++)
                    f[l][r]=min(f[l][r],f[l][k]+f[k+1][r]+s[r]-s[l-1]);
            }
        }
    cout<<f[1][n]<<endl;
    
    return 0;
}

固然,也能够采起记忆化搜索,这样不须要考虑太多。

例题

环形石子合并:https://www.acwing.com/activity/content/problem/content/1297/1/

分析

这题无非是将上题排成一列的物品放在了环上,所以咱们能够采起断环成链的技巧:
显然,合并 \(n\) 个物品须要 \(n-1\) 步,所以,必然存在两个物品,它们并无进行合并,那么它们之间便出现了“断边”,这样的“断边”并不会参与到合并的过程当中,问题便由环转化为链的状况,因此咱们只需枚举“断边”,而后进行求解便可。

有一个技巧:只需将原有的物品再按顺序“复制”一份,分别获得区间:

对于样例:

4 5 9 4

复制:

4 5 9 4 4 5 9 4

而后依次把区间(记为 \([s,t]\) )取出求解:

s     t
4 5 9 4 4 5 9 4
s     t
4 5 9 4 4 5 9 4
s     t
4 5 9 4 4 5 9 4
s     t
4 5 9 4 4 5 9 4

(最后一个复制的元素是没用的,能够忽略)

这样分别求解四个子问题就好了。

代码:

#include<bits/stdc++.h>
using namespace std;

#define INF 0x3f3f3f3f

const int N = 410;

int f[N][N],g[N][N];
int s[N],w[N];
int n;

int main(){
    cin>>n;
    for(int i=1;i<=n;i++){
        cin>>w[i];
        w[i+n]=w[i];
    }
    
    memset(f,0x3f,sizeof f);
    memset(g,0xcf,sizeof g);
    
    for(int i=1;i<=2*n;i++) s[i]=s[i-1]+w[i];
    
    for(int len=1;len<=n;len++){
        for(int l=1;l+len-1<=n*2;l++){
            int r=l+len-1;
            
            if(len==1) f[l][r]=g[l][r]=0;
            else{
                for(int k=l;k<r;k++){
                    f[l][r]=min(f[l][r],f[l][k]+f[k+1][r]+s[r]-s[l-1]);
                    g[l][r]=max(g[l][r],g[l][k]+g[k+1][r]+s[r]-s[l-1]);
                }
            }
                
        }
    }
    
    int maxv=-INF,minv=INF;
    for(int i=1;i<=n;i++){
        maxv=max(maxv,g[i][i+n-1]);
        minv=min(minv,f[i][i+n-1]);
    }
    
    cout<<minv<<endl<<maxv<<endl;
    
    return 0;
}

记忆化搜索版本:(比较久以前写的emm)

#include<bits/stdc++.h>
using namespace std;
#define maxn 101
int n;
int a[maxn<<1];
int f_max[maxn][maxn];
int f_min[maxn][maxn];
int rec[maxn];
int s[maxn];

int sum(int l,int r){
    return s[r]-s[l-1];
}

int dfs_max(int l,int r){
    if(l==r) return f_max[l][r]=0;
    if(f_max[l][r]) return f_max[l][r];

    int res=0;
    for(int k=l;k+1<=r;k++){
        res=max(res,dfs_max(l,k)+dfs_max(k+1,r)+sum(l,r));
    }
    return f_max[l][r]=res;
}

int dfs_min(int l,int r){
    if(l==r) return f_min[l][r]=0;
    if(f_min[l][r]) return f_min[l][r];

    int res=INT_MAX;
    for(int k=l;k+1<=r;k++){
        res=min(res,dfs_min(l,k)+dfs_min(k+1,r)+sum(l,r));
    }
    return f_min[l][r]=res;
}

int main(){
    cin>>n;
    for(int i=1;i<=n-1;i++) cin>>a[i],a[i+n]=a[i];
    cin>>a[n];

    int rec_max=0;
    int rec_min=INT_MAX;

    for(int st=1;st<=n;st++){
        memset(rec,0,sizeof(rec));
        memset(s,0,sizeof(s));
        memset(f_max,0,sizeof(f_max));
        memset(f_min,0,sizeof(f_min));
        for(int i=st;i<=st+n-1;i++) rec[i-st+1]=a[i];

        s[1]=rec[1];
        for(int i=2;i<=n;i++) s[i]=s[i-1]+rec[i];

        rec_max=max(rec_max,dfs_max(1,n));
        rec_min=min(rec_min,dfs_min(1,n));
    }
    cout<<rec_min<<endl;
    cout<<rec_max<<endl;
    return 0;
}

能量项链:https://www.acwing.com/problem/content/322/

分析
和上面题目相似(事实上区间DP的题都差很少),要注意理解是如何合并珠子的。

代码:

#include<bits/stdc++.h>
using namespace std;

const int N=105;

int n;
int w[N<<1];
int f[N<<1][N<<1];

int main(){
    cin>>n;
    for(int i=1;i<=n;i++){
        cin>>w[i];
        w[n+i]=w[i];
    }
    
    for(int len=3;len<=n+1;len++)
        for(int l=1;l+len-1<=2*n;l++){
            int r=l+len-1;
            for(int k=l+1;k<=r-1;k++)
                f[l][r]=max(f[l][r],f[l][k]+f[k][r]+w[l]*w[k]*w[r]);
        }
        
    int res=0;
    for(int i=1;i<=n;i++) res=max(res,f[i][i+n]);
    
    cout<<res<<endl;
    
    return 0;
}

记忆化搜索版本:

#include<bits/stdc++.h>
using namespace std;

const int N=210;

int n;
int w[N];
int f[N][N];

int dp(int l,int r){
    if(f[l][r]>=0) return f[l][r];
    if(r==l || r==l+1) return f[l][r]=0;
    
    int &v=f[l][r];
    for(int k=l+1;k<=r-1;k++){
        v=max(v,dp(l,k)+dp(k,r)+w[l]*w[k]*w[r]);
    }
    return v;
}

int main(){
    cin>>n;
    for(int i=1;i<=n;i++){
        cin>>w[i];
        w[n+i]=w[i];
    }
    
    memset(f,-1,sizeof f);
    
    int res=0;
    for(int i=1;i<=n;i++) res=max(res,dp(i,i+n));
    
    cout<<res<<endl;
    
    return 0;
}

加分二叉树:https://www.acwing.com/problem/content/481/

分析

g[l][r] 表示 \([l,r]\) 的根节点。
将中序遍历的序列看做是区间求解,而后枚举根节点(将它做为断点),记录答案的过程当中要注意当答案获得更新的时候才记录这个区间的根节点。

#include<bits/stdc++.h>
using namespace std;

const int N=35;

int f[N][N]; //dp
int g[N][N]; //path

int n;
int w[N];

void dfs(int l,int r){
    if(l>r) return;
    
    int root=g[l][r];
    cout<<root<<' ';
    dfs(l,root-1);
    dfs(root+1,r);
}
int main(){
    cin>>n;
    for(int i=1;i<=n;i++) cin>>w[i];
    
    for(int len=1;len<=n;len++)
        for(int l=1;l+len-1<=n;l++){
            int r=l+len-1;
            if(len==1){
                f[l][r]=w[l];
                g[l][r]=l;
            }
            else{
                for(int k=l;k<=r;k++){
                    int left= k==l?1:f[l][k-1];
                    int right= k==r?1:f[k+1][r];
                    int score=left*right + w[k];
                    if(score>f[l][r]){
                        f[l][r]=score;
                        g[l][r]=k;
                    }
                }
            }
        }
    
    cout<<f[1][n]<<endl;
    dfs(1,n);
    
    return 0;
}
相关文章
相关标签/搜索