【題目鏈接】
http://acm.hust.edu.cn/vjudge/problem/viewProblem.action?id=19461
【解題報告】
題目大意:給你一個N元序列,每個元有一個分數,每次可以從左邊拿或者從右邊拿任意多。A,B輪流拿,拿完爲止,求問A最多比B多拿多少。其中N<=100.
因爲兩個人都足夠聰明,因此當他們任意一人面臨一個(i,j)的局面時,因爲序列總和sum[j]-sum[i-1]是不變的,所以一定會選擇拿完之後另一個人面臨局面(k,l)得分最少。這是類似於博弈論的思想。
所以我們這樣設計dp狀態:
dp[i][j]表示面臨(i,j)局面的人最多可以拿多少分
那麼它可以轉移到
S={ dp(i+1,j), dp(i+2,j) … dp(j,j) , dp(i,j-1) , dp(i,j-2) … , dp(i,i) , 0 }
一口氣拿完時,另一個人面臨的局面就是0.
所以狀態轉移方程就是dp[i][j]=sum-min;
DFS區間更新即可。
需要注意的是這樣做的時間複雜度是O(n^3),仍然有很大的優化空間。
優化:
對於min{ dp(i+1,j), dp(i+2,j) ... dp(j,j) , dp(i,j-1) , dp(i,j-2) ... , dp(i,i) }
設f(i+1,j)=min{ dp(i+1,j), dp(i+2,j) ... dp(j,j) }
設g(i.j-1)=min{ dp(i,j-1) , dp(i,j-2) ... , dp(i,i) }
那麼dp[i][j]=sum-min{ f(i+1,j),g(i,j-1),0 }
其中f和g均可以通過遞推得出:
f[i][j]=min{ f[i+1][j],dp[i][j] }
g[i][j]=min{ g[i][j-1],dp[i][j] }
所以時間複雜度被降到了O(n^2)
【參考代碼】
1.O(N^3)
#include<cstdio>
#include<iostream>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<stack>
#include<queue>
#include<vector>
#include<map>
using namespace std;
const int INF=2e9+1e8;
int n;
int dp[100+10][100+10];
int a[100+10],sum[100+10];
int DFS( int i ,int j )
{
if( dp[i][j]!=INF )return dp[i][j];
int minn=INF;
for( int ii=i+1; ii<=j; ii++ )minn=min( minn, DFS( ii,j ) );
for( int jj=j-1; jj>=i; jj-- )minn=min( minn, DFS( i,jj ) );
minn=min( minn,0 ); //i~j全部拿完的狀況
return dp[i][j]=sum[j]-sum[i-1]-minn;
}
int main()
{
while( ~scanf("%d",&n) && n )
{
for( int i=1; i<=n; i++ )scanf("%d",&a[i]);
for( int i=1; i<=n; i++ )
for( int j=1; j<=n; j++ )
{
if( i==j )dp[i][j]=a[i];
else dp[i][j]=INF;
}
memset(sum,0,sizeof(sum));
for( int i=1; i<=n; i++ )sum[i]=sum[i-1]+a[i];
// dp[i][j]=sum(i,j)-min( dp[i+1][j], dp[i+2][j],...dp[j][j], dp[i][j-1],...,dp[i][i] );
printf( "%d\n",2*DFS(1,n)-sum[n] );
}
return 0;
}
2.O(N^2)
#include<cstdio>
#include<iostream>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<stack>
#include<queue>
#include<vector>
#include<map>
using namespace std;
const int INF=2e9+1e8;
int n;
int dp[100+10][100+10],f[100+10][100+10],g[100+10][100+10];
int a[100+10],sum[100+10];
int main()
{
while( ~scanf("%d",&n) && n )
{
for( int i=1; i<=n; i++ )scanf("%d",&a[i]);
memset(sum,0,sizeof sum);
for( int i=1; i<=n; i++ )sum[i]=sum[i-1]+a[i];
memset(f,0,sizeof f);
memset(g,0,sizeof(g));
for( int i=n; i>=1; i-- )
for( int j=i; j<=n; j++ )
{
int temp=min( f[i+1][j],g[i][j-1] );
temp=min(0,temp);
dp[i][j]=sum[j]-sum[i-1]-temp;
f[i][j]=min( f[i+1][j],dp[i][j] );
g[i][j]=min( g[i][j-1],dp[i][j] );
}
printf( "%d\n",2*dp[1][n]-sum[n] );
}
return 0;
}