【圖論】B040_NK_“好序列”的個數(快速冪 + 求差)

一、描述

現在你面前有一棵n個節點的樹(全連通無環圖)。樹上的邊只有2種顏色,紅色或者黑色。現在還給你一個整數k,考慮下面這個k個節點的序列[a1, a2, …, ak]。

[a1, a2, …, ak]如果是”好序列“當且僅當滿足下面的條件:

  1. 我們要走一條從a1開始到ak結束的路徑。
  2. 從a1開始,到a2走一條a1到a2的最短路。然後從a2開始,繼續走一條到a3的最短路,以此類推,最終到a(k-1)和ak。
  3. 走的路徑中至少包含一條黑色的邊。
    在這裏插入圖片描述

我們看一下上面的圖片中的樹,如果k=3,那麼下面的序列是“好序列”:[1,4,7], [5,5,3]。下面的序列不是好序列: [1,4,6], [5,5,5], [3,7,3]。

總共有 nkn^k(n的k次方種路徑方案),那麼有多少路徑是“好序列”呢?這個值可能非常大,輸出的結果對(10^9+7)取模就可以。

輸入描述:

第一行是2個整數n和k,其中(2 <= n <= 10^5, 2 <= k <= 100),n表示樹的節點個數,k表示序列的長度。

下面n-1行,每行包含3個整數,u[i], v[i], w[i],其中1 <= u[i], v[i] <= n, w[i] = 0或1。u[i], v[i]表示這兩個節點之間有一條邊,w[i]表示這條邊的顏色,其中0表示紅色,1表示黑色。

輸出描述:

輸出所有“好序列”的個數模(10^9+7)

輸入
4 4
1 2 1
2 3 1
3 4 1

輸出
252

說明
這個例子中,所有序列一共有4^4 = 256個,其中不是好序列的只有4個:
[1, 1, 1, 1]

[2, 2, 2, 2]

[3, 3, 3, 3]

[4, 4, 4, 4]

二、Solution

方法一:求差

  • 直接求好序列比較難,因爲遍歷的時候還要統計 黑邊 的數量
  • 而壞序列只包紅邊,又由題意得一個壞子圖的結點數爲 sz 時,那麼該子圖就有 szksz^k 個壞序列,所以我們只需求出每個壞子圖 ii 的結點數 szisz_i,最後用 totszitot - sz_itot=nk(tot = n^k) 即爲所求答案。

細節:一般涉及到取模的題都需要仔細觀察能取模的地方,比如這裏,我提交時沒有寫上 + mod,直接 WA 了。

System.out.println((tot - bad + mod) % mod);

原因是爲了防止負數取模仍是負數的問題,例如 -3 % 4 在 Java 中會得到 -3,而某些題目結果不小於 0,所以要加上 mod 確保結果非負(由取模定理得結果是不會改變的)

import java.util.*;
import java.math.*;
import java.io.*;
public class Main{
    static class Solution {
        Set<Integer> vis, g[];
        int mod = (int) 1e9+7;
        
        long qPow(long b, long p) {
            long ans = 1;
            while (p > 0) {
                if ((p & 1) == 1)
                    ans = (ans * b) % mod;
                b = (b * b) % mod;
                p >>= 1;
            }
            return ans;
        }
        long dfs(int u) {
            long sz = 1;
            vis.add(u);
            for (int v : g[u]) if (!vis.contains(v)) {
                sz += dfs(v);
            }
            return sz;
        }
        void init() {
            Scanner sc = new Scanner(new BufferedInputStream(System.in));
            int n = sc.nextInt(), k = sc.nextInt();
            long tot = qPow(n, k);

            vis = new HashSet<>();
            g = new HashSet[n+1];
            for (int i = 1; i <= n; i++) g[i] = new HashSet<>();
            for (int i = 1; i < n; i++) {
                int a = sc.nextInt(), b = sc.nextInt(), w = sc.nextInt();
                if (w == 0) {
                    g[a].add(b);
                    g[b].add(a);
                }
            }
            long bad = 0;
            for (int i = 1; i <= n; i++) if (!vis.contains(i)) {
                long sz = dfs(i);
                bad = (bad + qPow(sz, k))  % mod;
            }
            System.out.println((tot - bad + mod) % mod);
        }
    }
    public static void main(String[] args) throws IOException {  
        Solution s = new Solution();
        s.init();
    }
}

複雜度分析

  • 時間複雜度:O(n)O(n)
  • 空間複雜度:O(n)O(n)
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章