【POJ 3233】Matrix Power Series

【題目鏈接】

          點擊打開鏈接

【算法】

          要求 A^1 + A^2 + A^3 + ... + A^k

          考慮通過二分來計算這個式子 :

          令f(k) = A^1 + A^2 + A ^ 3 + ... + A^k

          那麼,當k爲奇數時,f(k) = f(k-1) + A ^ k

                    當k爲偶數時,f(k) = f(n/2) + A ^ (n/2) * f(n/2)

          因此,可以通過二分 + 矩陣乘法快速冪的方式,在O(n^3log(n)^2)的時間內解決此題

【代碼】

         

#include <algorithm>
#include <bitset>
#include <cctype>
#include <cerrno>
#include <clocale>
#include <cmath>
#include <complex>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <ctime>
#include <deque>
#include <exception>
#include <fstream>
#include <functional>
#include <limits>
#include <list>
#include <map>
#include <iomanip>
#include <ios>
#include <iosfwd>
#include <iostream>
#include <istream>
#include <ostream>
#include <queue>
#include <set>
#include <sstream>
#include <stdexcept>
#include <streambuf>
#include <string>
#include <utility>
#include <vector>
#include <cwchar>
#include <cwctype>
#include <stack>
#include <limits.h>
using namespace std;
#define MAXN 35

int i,j,n,k,m;
struct Matrix
{
		int mat[MAXN][MAXN];
} a,ans;

inline Matrix add(Matrix a,Matrix b)
{
		int i,j;
		Matrix ans;
		memset(ans.mat,0,sizeof(ans.mat));
		for (i = 1; i <= n; i++)
		{
				for (j = 1; j <= n; j++)
				{
						ans.mat[i][j] = (a.mat[i][j] + b.mat[i][j]) % m;
				}
		}
		return ans;
}
inline Matrix mul(Matrix a,Matrix b)
{
		int i,j,k;
		Matrix ans;
		memset(ans.mat,0,sizeof(ans.mat));
		for (i = 1; i <= n; i++)
		{
				for (j = 1; j <= n; j++)
				{
						for (k = 1; k <= n; k++)
						{
								ans.mat[i][j] = (ans.mat[i][j] + a.mat[i][k] * b.mat[k][j]) % m;
						}
				}
		}
		return ans;
}
inline Matrix power(Matrix a,int m)
{
		Matrix ans,p = a;
		for (i = 1; i <= n; i++)
		{
				for (j = 1; j <= n; j++)
				{
						ans.mat[i][j] = (i == j);	
				}	
		}	
		while (m > 0)
		{
				if (m & 1) ans = mul(ans,p);
				p = mul(p,p);
				m >>= 1;
		}
		return ans;
}
Matrix solve(int n)
{
		Matrix tmp;
		if (n == 1) return a;
		if (n % 2 == 0)
		{
				tmp = solve(n/2);
				return add(tmp,mul(power(a,n/2),tmp));
		} else return add(solve(n-1),power(a,n));
 }

int main() 
{
		
		scanf("%d%d%d",&n,&k,&m);
		for (i = 1; i <= n; i++)
		{
				for (j = 1; j <= n; j++)
				{
						scanf("%d",&a.mat[i][j]);
				}
		}
		ans = solve(k);
		for (i = 1; i <= n; i++)
		{
				for (j = 1; j < n; j++)
				{
						printf("%d ",ans.mat[i][j]);
				}
				printf("%d\n",ans.mat[i][n]);
		}
		
		return 0;
	
}

         

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章