題解:
第 i 列的狀態有四種:(黑,黑),(黑,白),(白,黑),(白,白),設爲0(0,0), 1(0,1), 2(1,0), 3(1,1)。
dp[i][k][j]:i 表示第 i 列,k 表示有 k 種,j 表示第 i 列的狀態。
那麼我們可以得到:
dp[i][k][0] = dp[i-1][k][0] + dp[i-1][k][1] + dp[i-1][k][2] + dp[i-1][k-1][3]。
同理:dp[i][k][3] = dp[i-1][k][3] + dp[i-1][k][1] + dp[i-1][k][2] + dp[i-1][k-1][0]。
因爲第 i 列的狀態是純色的,只要上一列中有和這一列一樣的顏色就可以疊加並且 k 不變,如果和上一個完全不一樣,則種類會多加一個,所以要取 k-1 。
dp[i][k][1] = dp[i-1][k][1] + dp[i-1][k-2][2] + dp[i-1][k-1][0] + dp[i-1][k-1][3]。
同理:dp[i][k][2] = dp[i-1][k][2] + dp[i-1][k-2][1] + dp[i-1][k-1][0] + dp[i-1][k-1][3]。
因爲第 i 列的狀態是雜色的,所以要考慮如果上一個和當前狀態一樣則加上,如果完全不一樣則要找 k-2 的狀態加上,如果是純色會多一種,則找 k-1 的狀態加上。
#include <algorithm>
#include <iostream>
#include <cstdlib>
#include <cstring>
#include <cstdio>
#include <string>
#include <vector>
#include <bitset>
#include <stack>
#include <cmath>
#include <deque>
#include <queue>
#include <list>
#include <set>
#include <map>
#define line printf("---------------------------\n")
#define mem(a, b) memset(a, b, sizeof(a))
#define pi acos(-1)
using namespace std;
typedef long long ll;
const double eps = 1e-9;
const int inf = 0x3f3f3f3f;
const int mod = 998244353;
const int maxn = 2000+10;
ll dp[1000+10][2000+10][4];
/**
dp[1][1][0,0] = 1;
dp[1][1][0,1] = 0;
dp[1][1][1,0] = 0;
dp[1][1][1,1] = 1;
dp[1][2][0,0] = 0;
dp[1][2][0,1] = 1;
dp[1][2][1,0] = 1;
dp[1][2][1,1] = 0;
dp[2][1][0,0] = 1;
dp[2][1][0,1] = 0;
dp[2][1][1,0] = 0;
dp[2][1][1,1] = 1;
dp[2][2][0,0] = 1+1+1+0;
dp[2][2][0,1] = 1+1+1+0;
dp[2][2][1,0] = 1+1+1+0;
dp[2][2][1,1] = 1+1+1+0;
dp[i][k][0] = dp[i-1][k][1]+dp[i-1][k][2]+dp[i-1][k][0]+dp[i-1][k-1][3];
dp[i][k][3] = dp[i-1][k][1]+dp[i-1][k][2]+dp[i-1][k][3]+dp[i-1][k-1][0];
dp[i][k][1] = dp[i-1][k][1]+dp[i-1][k-2][2]+dp[i-1][k-1][0]+dp[i-1][k-1][3];
dp[i][k][2] = dp[i-1][k][2]+dp[i-1][k-2][1]+dp[i-1][k-1][0]+dp[i-1][k-1][3];
*/
int main(){
int n, K;
while(~scanf("%d %d", &n, &K)){
mem(dp, 0LL);
dp[1][1][0] = 1;
dp[1][1][3] = 1;
dp[1][2][1] = 1;
dp[1][2][2] = 1;
for(int i = 2; i <= n; i++){
dp[i][1][0] = dp[i-1][1][0];
dp[i][1][1] = dp[i-1][1][1];
dp[i][1][2] = dp[i-1][1][2];
dp[i][1][3] = dp[i-1][1][3];
for(int k = 2; k <= K; k++){
dp[i][k][0] = (dp[i-1][k][1]+dp[i-1][k][2]+dp[i-1][k][0]+dp[i-1][k-1][3]) % mod;
dp[i][k][3] = (dp[i-1][k][1]+dp[i-1][k][2]+dp[i-1][k][3]+dp[i-1][k-1][0]) % mod;
dp[i][k][1] = (dp[i-1][k][1]+dp[i-1][k-2][2]+dp[i-1][k-1][0]+dp[i-1][k-1][3]) % mod;
dp[i][k][2] = (dp[i-1][k][2]+dp[i-1][k-2][1]+dp[i-1][k-1][0]+dp[i-1][k-1][3]) % mod;
}
}
printf("%lld\n", (dp[n][K][0]+dp[n][K][1]+dp[n][K][2]+dp[n][K][3])%mod);
}
}