oj地址:wyh的考核
贴下题,好让搜索引擎搜集。
问题描述
wyh非常喜欢lol游戏,一天,他听说学校要选拔lol队员,他非常高兴的就去了,选拔规则是,一共有N个评委,每个评委根据线上对线能力和打团能力给出一个[0,M]之间的一个整数作为分数,然后取平均值,wyh学长非常好奇,他想知道有多少种这样的情况:
平均分等于其中某一位评委给的分数
例如2个评委,给打的分数都是1分,那么此时平均分是1分,即等于第一个评委的分数,又等于第二个评委的分数,这样答案是2
但是由于每个评委打的分都是在[0,M]之间,所以会有很多种情况。
现在请你帮助你们wyh学长数一下有多少种这样的情况,由于结果会很大,请你对1000000007取余
解析
dp定义
dp[i][j]表示i个人,j分的情况下,有多少种情况符合题意。
res定位
则答案res为dp[n-1][i*(n-1)],(i的范围是0~m),相加再乘以n即可。
进一步解释res,对于任意一个dp[n-1][i*(n-1)],表示有n-1个人,平均分为i是,有多少种情况符合题意,此时再增加一个人,只要这个人的分给成i,则就可以复用dp[n-1][i*(n-1)]的值,又根据排列组合可知,要乘以n(因为n-1的意义不是代表前n-1个人,而是任意选择n-1个人,也就是c(n,n-1)=n)。
状态转移推导:
对于dp[i][j],只需要枚举当前人能给出的分数即可(0~m),当前人给分为k时,可以得到dp[i-1][j-k]。挨个相加即可。
dp顺序:
dp[i][j]依赖于dp[i-1][j-k],可知,顺序为从上到下,从左到右。第一行为base case.
填写base case:
回归到题意,只有dp[0][0]是等于1的,其余情况都不存在,都为0。
初步代码如下:
import java.util.*;
public class Main {
static int mod=1000000007;
static long[][] dp;
public static void main(String[] args) {
Scanner sc=new Scanner(System.in);
while (sc.hasNext()){
int T=sc.nextInt();
while(T-->0){
dp=new long[61][12005];
int n=sc.nextInt();
int m=sc.nextInt();
dp[0][0]=1;
for(int i=1;i<=n;i++){
for(int j=0;j<=i*m;j++){
for(int k=0;k<=Math.min(j,m);k++){
dp[i][j]+=dp[i-1][j-k];
dp[i][j]=dp[i][j]%mod;
}
}
}
long res=0;
for(int i=0;i<=m;i++){
res=(res+dp[n-1][i*(n-1)])%mod;
}
System.out.println(res*n%mod);
}
}
}
}
此时虽然答案完全正确了,但是还是AC不了的,需要优化一下,也就是用辅助数据优化前n项和,定义sum[i][j]的意义为,dp表上第i行的前j项和是多少。然后发现,dp[i]永远只依赖dp[i-1],所以不需要sum[i][j],优化成一维的sum[i],复用就行。
优化代码如下(可以AC):
import java.util.*;
public class Main {
static int mod=1000000007;
static long[][] dp;
public static void main(String[] args) {
Scanner sc=new Scanner(System.in);
while (sc.hasNext()){
int T=sc.nextInt();
while(T-->0){
dp=new long[61][12005];
long[] sum=new long[12005];
Arrays.fill(sum,1);
int n=sc.nextInt();
int m=sc.nextInt();
dp[0][0]=1;
for(int i=1;i<=n;i++){
for(int j=0;j<=i*m;j++){
if(j>m){
dp[i][j]+=(sum[j]-sum[j-m-1]+mod)%mod;
}else{
dp[i][j]+=sum[j];
}
dp[i][j]=dp[i][j]%mod;
}
sum[0]=dp[i][0];
//下次循环要用,所以要小于(i+1)*m
for(int j=1;j<=(i+1)*m;j++){
sum[j]=(dp[i][j]+sum[j-1])%mod;
}
}
long res=0;
for(int i=0;i<=m;i++){
res=(res+dp[n-1][i*(n-1)])%mod;
}
System.out.println(res*n%mod);
}
}
}
}