算法:矩阵快速幂

矩阵快速幂

功能

快速计算矩阵AAbb次方幂

思路

将快速幂算法中的乘法运算替换为矩阵乘法。若将bb表示为pi×2i\sum p_i \times 2_i,则AbA^b可以表示为(A2i)pi\prod (A^{2^i})^{p_i},其中pip_i表示bb的二进制从右往左第ii位数字。

时间复杂度

O(logn)O(\log n)

测试

POJ:3070

应用

可根据递推公式快速求序列某项的值。如FibonacciFibonacci数列的递推公式可视为
$$
\left[
\begin{matrix}
a_n \
a_{n-1}
\end{matrix}
\right]

\left[
\begin{matrix}
1 & 1 \
1 & 0
\end{matrix}
\right]
\left[
\begin{matrix}
a_{n-1} \
a_{n-2}
\end{matrix}
\right]
$$

模板

typedef long long LL;
const int maxn = 100;

struct Mat {
  LL m[maxn][maxn];
};

/**
  * @param a: the matrix A
  * @param b: the matrix B
  * @return: A*B
  */
Mat mul(Mat a, Mat b, int n) {
  Mat ans;
  memset(ans.m, 0, sizeof(ans.m));
  for (int i = 1; i <= n; ++i) {
    for (int j = 1; j <= n; ++j) {
      for (int k = 1; k <= n; ++k) {
        ans.m[i][j] += a.m[i][k] * b.m[k][j];
      }
    }
  }
  return ans;
}

/**
  * @param a: the base matrix A
  * @param b: the exponent of power
  * @return: A^b
  * @other: b >= 0
  */
Mat FPM(Mat a, int b, int n) {
  Mat ans;
  memset(ans.m, 0, sizeof(ans.m));
  for (int i = 1; i <= n; ++i) {
    ans.m[i][i] = 1;
  }
  while (b > 0) {
    if (b & 1) ans = mul(ans, a);
    a = mul(a, a);
    b >>= 1;
  }
  return ans;
}

扩展

对于较大的数需要取模。

模板

#include <cstring>

typedef long long LL;
const int maxn = 100;
const int mod = 1e9+7;  // the divisor of answer

struct Mat {
  LL m[maxn][maxn];
};

/**
  * @param a: the matrix A
  * @param b: the matrix B
  * @return: A*B
  */
Mat mul(Mat a, Mat b, int n) {
  Mat ans;
  memset(ans.m, 0, sizeof(ans.m));
  for (int i = 1; i <= n; ++i) {
    for (int j = 1; j <= n; ++j) {
      for (int k = 1; k <= n; ++k) {
        ans.m[i][j] += a.m[i][k] * b.m[k][j];
        ans.m[i][j] %= mod;
      }
    }
  }
  return ans;
}

/**
  * @param a: the base matrix A
  * @param b: the exponent of power
  * @return: A^b
  * @other: b >= 0
  */
Mat FPM(Mat a, int b, int n) {
  for (int i = 1; i <= n; ++i) {
    for (int j = 1; j <= n; ++j) {
	  a.m[i][j] %= mod;
    }
  }
  Mat ans;
  memset(ans.m, 0, sizeof(ans.m));
  for (int i = 1; i <= n; ++i) {
    ans.m[i][i] = 1;
  }
  while (b > 0) {
    if (b & 1) ans = mul(ans, a);
    a = mul(a, a);
    b >>= 1;
  }
  return ans;
}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章