由於「線段樹」是平衡二叉樹,因此可以使用數組表示
- 以前我們學習過「堆」,知道「堆」是一棵「完全二叉樹」,因此「堆」可以用數組表示。基於此,我們很自然地想到可以用數組表示「線段樹」;
- 完全二叉樹的定義:除了最後一層以外,其餘各層的結點數達到最大,並且最後一層所有的結點都連續地、集中地存儲在最左邊;
- 線段樹雖然不是完全二叉樹,但線段樹是平衡二叉樹,依然也可以用數組表示。
如何構建線段樹、如何實現區間查詢、如何實現區間更新
「自頂向下」遞歸構建線段樹
- 首先看看「線段樹」長什麼樣;
- 線段樹是一種二叉樹結構,不過在實現的時候,可以使用數組實現,這一點和「優先隊列」、「並查集」是一致的。
當區間裏結點的個數恰好是 的方冪時:
需要多少空間
- 「線段樹」的一個經典實現是從上到下遞歸構建,這一點很像根據員工人數來定領導的人數,設置多少領導的個數就要看員工有多少人了;
- 再想一想,我們在開篇對於線段樹的介紹,線段樹適合支持的操作是「查詢」和「更新」,不適用於「添加」和「刪除」。
下面以「員工和領導」爲例,講解從上到下逐步構建線段樹的步驟:我們首先要解決的問題是「一共要設置多少領導」,我們寧可有一些位置沒有人坐,也要讓所有的人都坐下,因此我們在做估計的時候只會放大。
線段樹是一顆平衡二叉樹
比較極端的一種情況:
比較一般的一種情況:
- 線段樹是一棵二叉樹,除了最後一層以外,每一層都是「滿」的;
- 第 層( 從 0 開始計算)的結點個數爲 ;
- 第 層 之前的所有結點的個數之和:
- 假設 ,最壞情況下,還要佔用下一層,使用 空間,第 層 之前的所有結點的個數之和小於 ;
- 所以 這麼多空間肯定夠了。
根據上面的討論,我們可以寫出線段樹的代碼框架:
Java 代碼:
public class SegmentTree<E> {
// 一共要給領導和員工準備的椅子,是我們要構建的輔助數據結構
private E[] tree;
// 原始的領導和員工數據,這是一個副本
private E[] data;
public SegmentTree(E[] arr) {
this.data = data;
// 數組初始化
data = (E[]) new Object[arr.length];
for (int i = 0; i < arr.length; i++) {
data[i] = arr[i];
}
tree = (E[]) new Object[4 * arr.length];
}
public int getSize() {
return data.length;
}
public E get(int index) {
if (index < 0 || index >= data.length) {
throw new IllegalArgumentException("Index is illegal.");
}
return data[index];
}
/**
* 返回完全二叉樹的數組表示中,索引所表示的元素的左孩子結點的索引
* 注意:索引編號從 0 開始
*
* @param 線段樹的某個結點的索引
* @return 傳入的結點的左結點的索引
*/
public int leftChild(int index) {
return 2 * index + 1;
}
/**
* 返回完全二叉樹的數組表示中,索引所表示的元素的左孩子結點的索引
* 注意:索引編號從 0 開始
*
* @param 線段樹的某個結點的索引
* @return 傳入的結點的右結點的索引
*/
public int rightChild(int index) {
return 2 * index + 2;
}
}
Python 代碼:
class SegmentTree:
def __init__(self, arr):
self.data = arr
# 開 4 倍大小的空間
self.tree = [None for _ in range(4 * len(arr))]
def get_size(self):
return len(self.data)
def get(self, index):
if index < 0 or index >= len(self.data):
raise Exception("Index is illegal.")
return self.data[index]
def __left_child(self, index):
return 2 * index + 1
def __right_child(self, index):
return 2 * index + 2
「力扣」第 303 題:區域和檢索 - 數組不可變
- 題目鏈接:303. 區域和檢索 - 數組不可變
方法:基於線段樹(區間樹)的實現。
Java 代碼:
public class NumArray {
private SegmentTree<Integer> segmentTree;
public NumArray(int[] nums) {
// 把數組傳給線段樹
if(nums.length>0){
Integer[] data = new Integer[nums.length];
for (int i = 0; i < nums.length; i++) {
data[i] = nums[i];
}
segmentTree = new SegmentTree<>(data, (a, b) -> a + b);
}
}
public int sumRange(int i, int j) {
if(segmentTree==null){
throw new IllegalArgumentException("Segment Tree is null");
}
return segmentTree.query(i, j);
}
private interface Merge<E> {
E merge(E e1, E e2);
}
private class SegmentTree<E> {
private E[] tree;
private E[] data;
private Merge<E> merge;
public SegmentTree(E[] arr, Merge<E> merge) {
this.data = data;
this.merge = merge;
data = (E[]) new Object[arr.length];
for (int i = 0; i < arr.length; i++) {
data[i] = arr[i];
}
tree = (E[]) new Object[4 * arr.length];
buildSegmentTree(0, 0, arr.length - 1);
}
private void buildSegmentTree(int treeIndex, int l, int r) {
if (l == r) {
tree[treeIndex] = data[l]; // data[r],此時對應葉子節點的情況
return;// return 不能忘記
}
int mid = l + (r - l) / 2;
int leftChild = leftChild(treeIndex);
int rightChild = rightChild(treeIndex);
buildSegmentTree(leftChild, l, mid);
buildSegmentTree(rightChild, mid + 1, r);
tree[treeIndex] = merge.merge(tree[leftChild], tree[rightChild]);
}
// 在一棵子樹裏做區間查詢
public E query(int dataL, int dataR) {
if (dataL < 0 || dataL >= data.length || dataR < 0 || dataR >= data.length || dataL > dataR) {
throw new IllegalArgumentException("Index is illegal.");
}
return query(0, 0, data.length - 1, dataL, dataR);
}
private E query(int treeIndex, int l, int r, int dataL, int dataR) {
if (l == dataL && r == dataR) {
return tree[treeIndex];
}
int mid = l + (r - l) / 2;
int leftChildIndex = leftChild(treeIndex);
int rightChildIndex = rightChild(treeIndex);
if (dataR <= mid) {
return query(leftChildIndex, l, mid, dataL, dataR);
}
if (dataL >= mid + 1) {
return query(rightChildIndex, mid + 1, r, dataL, dataR);
}
E leftResult = query(leftChildIndex, l, mid, dataL, mid);
E rightResult = query(rightChildIndex, mid + 1, r, mid + 1, dataR);
return merge.merge(leftResult, rightResult);
}
public int getSize() {
return data.length;
}
public E get(int index) {
if (index < 0 || index >= data.length) {
throw new IllegalArgumentException("Index is illegal.");
}
return data[index];
}
public int leftChild(int index) {
return 2 * index + 1;
}
public int rightChild(int index) {
return 2 * index + 2;
}
}
}
Python 代碼:
class NumArray:
class SegmentTree:
def __init__(self, arr, merge):
self.data = arr
# 開 4 倍大小的空間
self.tree = [None for _ in range(4 * len(arr))]
if not hasattr(merge, '__call__'):
raise Exception('不是函數對象')
self.merge = merge
self.__build_segment_tree(0, 0, len(self.data) - 1)
def get_size(self):
return len(self.data)
def get(self, index):
if index < 0 or index >= len(self.data):
raise Exception("Index is illegal.")
return self.data[index]
def __left_child(self, index):
return 2 * index + 1
def __right_child(self, index):
return 2 * index + 2
def __build_segment_tree(self, tree_index, data_l, data_r):
# 區間只有 1 個數的時候,線段樹的值,就是數組的值,不必做融合
if data_l == data_r:
self.tree[tree_index] = self.data[data_l]
# 不要忘記 return
return
# 然後一分爲二去構建
mid = data_l + (data_r - data_l) // 2
left_child = self.__left_child(tree_index)
right_child = self.__right_child(tree_index)
self.__build_segment_tree(left_child, data_l, mid)
self.__build_segment_tree(right_child, mid + 1, data_r)
# 左右都構建好以後,再構建自己,因此是後續遍歷
self.tree[tree_index] = self.merge(self.tree[left_child], self.tree[right_child])
def __str__(self):
# 打印線段樹
return str([str(ele) for ele in self.tree])
def query(self, data_l, data_r):
if data_l < 0 or data_l >= len(self.data) or data_r < 0 or data_r >= len(self.data) or data_l > data_r:
raise Exception('Index is illegal.')
return self.__query(0, 0, len(self.data) - 1, data_l, data_r)
def __query(self, tree_index, tree_l, tree_r, data_l, data_r):
# 一般而言,線段樹區間肯定會大一些,所以會遞歸查詢下去
# 如果要查詢的線段樹區間和數據區間完全吻合,把當前線段樹索引的返回回去就可以了
if tree_l == data_l and tree_r == data_r:
return self.tree[tree_index]
mid = tree_l + (tree_r - tree_l) // 2
# 線段樹的左右兩個索引
left_child = self.__left_child(tree_index)
right_child = self.__right_child(tree_index)
# 因爲構建時是這樣
# self.__build_segment_tree(left_child, data_l, mid)
# 所以,如果右邊區間不大於中間索引,就只須要在左子樹查詢就可以了
if data_r <= mid:
return self.__query(left_child, tree_l, mid, data_l, data_r)
# 同理,如果左邊區間 >= mid + 1,就只用在右邊區間找就好了
# self.__build_segment_tree(right_child, mid + 1, data_r)
if data_l >= mid + 1:
return self.__query(right_child, mid + 1, tree_r, data_l, data_r)
# 橫跨兩邊的時候,先算算左邊,再算算右邊
left_res = self.__query(left_child, tree_l, mid, data_l, mid)
right_res = self.__query(right_child, mid + 1, tree_r, mid + 1, data_r)
return self.merge(left_res, right_res)
def __init__(self, nums):
"""
:type nums: List[int]
"""
if len(nums) > 0:
self.st = NumArray.SegmentTree(nums, lambda a, b: a + b)
def sumRange(self, i, j):
"""
:type i: int
:type j: int
:rtype: int
"""
if self.st is None:
return 0
return self.st.query(i, j)
# Your NumArray object will be instantiated and called as such:
# obj = NumArray(nums)
# param_1 = obj.sumRange(i,j)