矩陣快速冪 模板 學習筆記

矩陣快速冪

推薦模板題 洛谷P3390
矩陣乘法時間複雜度:n×mn \times mm×rm \times r的矩陣相乘,複雜度O(nmr)O(nmr)
計算AnA^n.矩陣乘法的次數O(log2n)O(\log_2{n}),總複雜度A3log2n|A|^3\log_2{n}.

// 除非是設置單位矩陣,否則必須調用set_size進行設置大小並清零(或指定值)的初始化
// 所有函數都預設傳入了正確的參數
// 矩陣乘法使用取模版,加減數乘未取模(默認它們不爆) 因爲一般題目也沒這些操作
struct mtr{
    int r_sz, c_sz;
    typedef ll item_type;
    typedef vector<item_type> row_type;
    vector<row_type> data;
    mtr():r_sz(0),c_sz(0),data(){}
    // 設置大小,並且全部元素設置爲item_val值
    void set_size(int r_size, int c_size, int item_val = 0) { 
        r_sz = r_size; c_sz = c_size;
        data.resize(r_sz);
        for (auto &row : data)
            row.resize(c_sz, item_val);
    }

    inline bool is_square() { return r_sz == c_sz; }

    // inline row_type& operator()(int r) { return data[r]; }
    // inline item_type& operator()(int r,int c) { return data[r][c];}

    // 會自動調用set_size,調用之前請勿調用set_size
    // 設置成n階單位矩陣
    void set_identity(int n) {
        set_size(n, n, 0);
        for (int i = 0; i < n; ++i)
            data[i][i] = 1;
    }
    void in() {
        for (int i = 0; i < r_sz; ++i)
            for (int j = 0; j < c_sz; ++j)
                scanf("%lld", &data[i][j]);
    }
    // 矩陣輸出,主要爲了調試
    void out() {
        for (auto &row : data) {
            for (auto &cell : row)
                cout<<cell<<" ";
            cout<<"\n";
        }
    }
    // 矩陣加,假設傳參合法
    mtr operator+(const mtr& obj) const {
        mtr ans;
        ans.set_size(r_sz, c_sz);
        for (int i = 0; i < r_sz; ++i)
            for (int j = 0; j < c_sz; ++j)
                ans.data[i][j] = data[i][j] + obj.data[i][j];
        return ans;
    }
    mtr operator-(const mtr& obj) const {
        mtr ans;
        ans.set_size(r_sz, c_sz);
        for (int i = 0; i < r_sz; ++i)
            for (int j = 0; j < c_sz; ++j)
                ans.data[i][j] = data[i][j] - obj.data[i][j];
        return ans;
    }
    // 矩陣數乘 數在右邊
    // 數乘 數在左邊必須在類外邊用函數實現,模板不提供,容易改出來
    mtr operator*(item_type obj) const {
        mtr ans;
        ans.set_size(r_sz, c_sz);
        for (int i = 0; i < r_sz; ++i)
            for (int j = 0; j < c_sz; ++j)
                ans.data[i][j] = data[i][j] * obj;
        return ans;
    }
    // 所有元素對mod取模(數學意義)
    void get_mod(ll mod) {
        for (int i = 0; i < r_sz; ++i)
            for (int j = 0; j < c_sz; ++j) {
                data[i][j] %= mod;
                if (data[i][j] < 0)
                    data[i][j] += mod;
            }
    }
    // 矩陣乘法 不用運算符乘號進行重載,便於增加mod參數修改成取模版
    // 默認元素乘法不爆long long,否則需要引入mod_sys模板
    // 默認待兩個輸入矩陣已經get_mod規約過了。
    mtr mlt(const mtr& obj, ll mod) const {
        mtr ans;
        ans.set_size(r_sz, obj.c_sz);
        for (int i = 0; i < r_sz; ++i)
            for (int j = 0; j < obj.c_sz; ++j) {
                item_type t = 0;
                for (int k = 0; k < c_sz; ++k)
                    t = (t+(data[i][k]*obj.data[k][j])%mod)%mod;
                ans.data[i][j] = t;
            }
        return ans;
    }
    // 預設n>=0
    mtr pow(ll n, ll mod) const {
        mtr a = *this;
        mtr t;
        t.set_identity(r_sz);
        // (a)^n*t
        if (n == 0) return t;
        while (n>1) {
            if (n&1) t = a.mlt(t, mod);
            n >>= 1; a = a.mlt(a, mod);
        }
        return a.mlt(t, mod);
    }
};
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章