「線段樹」第 2 節:寫出預處理數組的結構

由於「線段樹」是平衡二叉樹,因此可以使用數組表示

  • 以前我們學習過「堆」,知道「堆」是一棵「完全二叉樹」,因此「堆」可以用數組表示。基於此,我們很自然地想到可以用數組表示「線段樹」;
  • 完全二叉樹的定義:除了最後一層以外,其餘各層的結點數達到最大,並且最後一層所有的結點都連續地、集中地存儲在最左邊;
  • 線段樹雖然不是完全二叉樹,但線段樹是平衡二叉樹,依然也可以用數組表示。

如何構建線段樹、如何實現區間查詢、如何實現區間更新

「自頂向下」遞歸構建線段樹

  • 首先看看「線段樹」長什麼樣;
  • 線段樹是一種二叉樹結構,不過在實現的時候,可以使用數組實現,這一點和「優先隊列」、「並查集」是一致的。

當區間裏結點的個數恰好是 22 的方冪時:
在這裏插入圖片描述

需要多少空間

  • 「線段樹」的一個經典實現是從上到下遞歸構建,這一點很像根據員工人數來定領導的人數,設置多少領導的個數就要看員工有多少人了;
  • 再想一想,我們在開篇對於線段樹的介紹,線段樹適合支持的操作是「查詢」和「更新」,不適用於「添加」和「刪除」。

下面以「員工和領導」爲例,講解從上到下逐步構建線段樹的步驟:我們首先要解決的問題是「一共要設置多少領導」,我們寧可有一些位置沒有人坐,也要讓所有的人都坐下,因此我們在做估計的時候只會放大

線段樹是一顆平衡二叉樹

比較極端的一種情況:

在這裏插入圖片描述

比較一般的一種情況:

在這裏插入圖片描述

  • 線段樹是一棵二叉樹,除了最後一層以外,每一層都是「滿」的;
  • ii 層(ii 從 0 開始計算)的結點個數爲 2i2^i
  • ii 層 之前的所有結點的個數之和:20+21+22++2i1=1×(12i)12=2i1<2i2^0 + 2^1 + 2^2 + \dots + 2^{i-1} = \cfrac{1 \times (1 - 2^i)}{1 - 2} = 2^i - 1 < 2^i
  • 假設 N=2iN = 2^i,最壞情況下,還要佔用下一層,使用 2N2N 空間,第 ii 層 之前的所有結點的個數之和小於 NN
  • 所以 N+N+2N=4NN + N + 2N = 4N 這麼多空間肯定夠了。

根據上面的討論,我們可以寫出線段樹的代碼框架:

Java 代碼:

public class SegmentTree<E> {
    // 一共要給領導和員工準備的椅子,是我們要構建的輔助數據結構
    private E[] tree;
    // 原始的領導和員工數據,這是一個副本
    private E[] data;

    public SegmentTree(E[] arr) {
        this.data = data;
        // 數組初始化
        data = (E[]) new Object[arr.length];
        for (int i = 0; i < arr.length; i++) {
            data[i] = arr[i];
        }
        tree = (E[]) new Object[4 * arr.length];
    }

    public int getSize() {
        return data.length;
    }

    public E get(int index) {
        if (index < 0 || index >= data.length) {
            throw new IllegalArgumentException("Index is illegal.");
        }
        return data[index];
    }

    /**
     * 返回完全二叉樹的數組表示中,索引所表示的元素的左孩子結點的索引
     * 注意:索引編號從 0 開始
     *
     * @param 線段樹的某個結點的索引
     * @return 傳入的結點的左結點的索引
     */
    public int leftChild(int index) {
        return 2 * index + 1;
    }

    /**
     * 返回完全二叉樹的數組表示中,索引所表示的元素的左孩子結點的索引
     * 注意:索引編號從 0 開始
     *
     * @param 線段樹的某個結點的索引
     * @return 傳入的結點的右結點的索引
     */
    public int rightChild(int index) {
        return 2 * index + 2;
    }

}

Python 代碼:

class SegmentTree:

    def __init__(self, arr):
        self.data = arr
        # 開 4 倍大小的空間
        self.tree = [None for _ in range(4 * len(arr))]

    def get_size(self):
        return len(self.data)

    def get(self, index):
        if index < 0 or index >= len(self.data):
            raise Exception("Index is illegal.")
        return self.data[index]

    def __left_child(self, index):
        return 2 * index + 1

    def __right_child(self, index):
        return 2 * index + 2

「力扣」第 303 題:區域和檢索 - 數組不可變

方法:基於線段樹(區間樹)的實現。

Java 代碼:

public class NumArray {

    private SegmentTree<Integer> segmentTree;

    public NumArray(int[] nums) {
        // 把數組傳給線段樹
        if(nums.length>0){
            Integer[] data = new Integer[nums.length];
            for (int i = 0; i < nums.length; i++) {
                data[i] = nums[i];
            }
            segmentTree = new SegmentTree<>(data, (a, b) -> a + b);
        }
    }

    public int sumRange(int i, int j) {
        if(segmentTree==null){
            throw new IllegalArgumentException("Segment Tree is null");
        }
        return segmentTree.query(i, j);
    }

    private interface Merge<E> {
        E merge(E e1, E e2);
    }

    private class SegmentTree<E> {

        private E[] tree;
        private E[] data;
        private Merge<E> merge;

        public SegmentTree(E[] arr, Merge<E> merge) {
            this.data = data;
            this.merge = merge;
            data = (E[]) new Object[arr.length];
            for (int i = 0; i < arr.length; i++) {
                data[i] = arr[i];
            }
            tree = (E[]) new Object[4 * arr.length];
            buildSegmentTree(0, 0, arr.length - 1);
        }

        private void buildSegmentTree(int treeIndex, int l, int r) {
            if (l == r) {
                tree[treeIndex] = data[l]; // data[r],此時對應葉子節點的情況
                return;// return 不能忘記
            }
            int mid = l + (r - l) / 2;
            int leftChild = leftChild(treeIndex);
            int rightChild = rightChild(treeIndex);
            buildSegmentTree(leftChild, l, mid);
            buildSegmentTree(rightChild, mid + 1, r);
            tree[treeIndex] = merge.merge(tree[leftChild], tree[rightChild]);
        }

        // 在一棵子樹裏做區間查詢
        public E query(int dataL, int dataR) {
            if (dataL < 0 || dataL >= data.length || dataR < 0 || dataR >= data.length || dataL > dataR) {
                throw new IllegalArgumentException("Index is illegal.");
            }
            return query(0, 0, data.length - 1, dataL, dataR);
        }

        private E query(int treeIndex, int l, int r, int dataL, int dataR) {
            if (l == dataL && r == dataR) {
                return tree[treeIndex];
            }
            int mid = l + (r - l) / 2;
            int leftChildIndex = leftChild(treeIndex);
            int rightChildIndex = rightChild(treeIndex);
            if (dataR <= mid) {
                return query(leftChildIndex, l, mid, dataL, dataR);
            }
            if (dataL >= mid + 1) {
                return query(rightChildIndex, mid + 1, r, dataL, dataR);
            }
            E leftResult = query(leftChildIndex, l, mid, dataL, mid);
            E rightResult = query(rightChildIndex, mid + 1, r, mid + 1, dataR);
            return merge.merge(leftResult, rightResult);
        }


        public int getSize() {
            return data.length;
        }

        public E get(int index) {
            if (index < 0 || index >= data.length) {
                throw new IllegalArgumentException("Index is illegal.");
            }
            return data[index];
        }

        public int leftChild(int index) {
            return 2 * index + 1;
        }

        public int rightChild(int index) {
            return 2 * index + 2;
        }
    }
}

Python 代碼:

class NumArray:
    class SegmentTree:

        def __init__(self, arr, merge):
            self.data = arr
            # 開 4 倍大小的空間
            self.tree = [None for _ in range(4 * len(arr))]
            if not hasattr(merge, '__call__'):
                raise Exception('不是函數對象')
            self.merge = merge
            self.__build_segment_tree(0, 0, len(self.data) - 1)

        def get_size(self):
            return len(self.data)

        def get(self, index):
            if index < 0 or index >= len(self.data):
                raise Exception("Index is illegal.")
            return self.data[index]

        def __left_child(self, index):
            return 2 * index + 1

        def __right_child(self, index):
            return 2 * index + 2

        def __build_segment_tree(self, tree_index, data_l, data_r):
            # 區間只有 1 個數的時候,線段樹的值,就是數組的值,不必做融合
            if data_l == data_r:
                self.tree[tree_index] = self.data[data_l]
                # 不要忘記 return
                return

            # 然後一分爲二去構建
            mid = data_l + (data_r - data_l) // 2
            left_child = self.__left_child(tree_index)
            right_child = self.__right_child(tree_index)

            self.__build_segment_tree(left_child, data_l, mid)
            self.__build_segment_tree(right_child, mid + 1, data_r)

            # 左右都構建好以後,再構建自己,因此是後續遍歷
            self.tree[tree_index] = self.merge(self.tree[left_child], self.tree[right_child])

        def __str__(self):
            # 打印線段樹
            return str([str(ele) for ele in self.tree])

        def query(self, data_l, data_r):
            if data_l < 0 or data_l >= len(self.data) or data_r < 0 or data_r >= len(self.data) or data_l > data_r:
                raise Exception('Index is illegal.')
            return self.__query(0, 0, len(self.data) - 1, data_l, data_r)

        def __query(self, tree_index, tree_l, tree_r, data_l, data_r):
            # 一般而言,線段樹區間肯定會大一些,所以會遞歸查詢下去
            # 如果要查詢的線段樹區間和數據區間完全吻合,把當前線段樹索引的返回回去就可以了
            if tree_l == data_l and tree_r == data_r:
                return self.tree[tree_index]

            mid = tree_l + (tree_r - tree_l) // 2
            # 線段樹的左右兩個索引
            left_child = self.__left_child(tree_index)
            right_child = self.__right_child(tree_index)

            # 因爲構建時是這樣
            # self.__build_segment_tree(left_child, data_l, mid)
            # 所以,如果右邊區間不大於中間索引,就只須要在左子樹查詢就可以了
            if data_r <= mid:
                return self.__query(left_child, tree_l, mid, data_l, data_r)
            # 同理,如果左邊區間 >= mid + 1,就只用在右邊區間找就好了
            # self.__build_segment_tree(right_child, mid + 1, data_r)
            if data_l >= mid + 1:
                return self.__query(right_child, mid + 1, tree_r, data_l, data_r)
            # 橫跨兩邊的時候,先算算左邊,再算算右邊
            left_res = self.__query(left_child, tree_l, mid, data_l, mid)
            right_res = self.__query(right_child, mid + 1, tree_r, mid + 1, data_r)
            return self.merge(left_res, right_res)

    def __init__(self, nums):
        """
        :type nums: List[int]
        """
        if len(nums) > 0:
            self.st = NumArray.SegmentTree(nums, lambda a, b: a + b)

    def sumRange(self, i, j):
        """
        :type i: int
        :type j: int
        :rtype: int
        """
        if self.st is None:
            return 0
        return self.st.query(i, j)

# Your NumArray object will be instantiated and called as such:
# obj = NumArray(nums)
# param_1 = obj.sumRange(i,j)
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章