引言 (Introduction)
前五篇我们围绕着二叉搜索树及其平衡变体(AVL、红黑树)展开。这些树的核心约束是节点间的大小关系。但有一种完全不同的应用场景:给定一个数组,我们需要频繁查询某个区间 $[l, r]$ 的聚合值(求和、最大值、最小值),同时数组的元素可能被动态修改。
朴素做法:区间查询 $O(N)$,点更新 $O(1)$。但若查询极为频繁(如每秒数万次),$O(N)$ 的代价不可接受。
线段树(Segment Tree)正是为这种场景量身定做的数据结构。它在 $O(N)$ 的时间内建树,支持 $O(\log N)$ 的单点更新和区间查询,搭配延迟传播(Lazy Propagation)后甚至可以做到 $O(\log N)$ 的区间更新。
1. 线段树的基本结构
1.1 完全二叉树的数组表示
线段树是一棵完全二叉树,每个节点代表数组的一个区间 $[l, r]$:
- 叶子节点:代表单个元素
arr[l] - 内部节点:代表
arr[l..r]的聚合值(如区间和sum[l..r])
由于是完全二叉树,我们可以用数组存储整棵树:
节点 i 的左孩子:2 * i + 1
节点 i 的右孩子:2 * i + 2
节点 i 的父节点:(i - 1) // 2
$N$ 个元素最多需要约 $4N$ 大小的数组(安全上界)。
[0,5] (node 0: sum[0..5])
/ \
[0,2] [3,5] (node 1,2)
/ \ / \
[0,1] [2,2] [3,4] [5,5]
/ \ / \
[0,0] [1,1] [3,3] [4,4]
1.2 建树、点更新、区间查询
以区间求和为例,实现线段树的基础版本:
from typing import List
class SegmentTree:
"""线段树:支持 O(log N) 的点更新与区间求和"""
def __init__(self, data: List[int]):
self.n = len(data)
self.tree: List[int] = [0] * (4 * self.n) # 安全容量
self._build(data, node=0, l=0, r=self.n - 1)
def _build(self, data: List[int], node: int, l: int, r: int) -> None:
"""递归构建线段树"""
if l == r:
self.tree[node] = data[l] # 叶子节点
return
mid = (l + r) // 2
left = 2 * node + 1
right = 2 * node + 2
self._build(data, left, l, mid)
self._build(data, right, mid + 1, r)
self.tree[node] = self.tree[left] + self.tree[right] # 合并子节点
def update(self, idx: int, val: int) -> None:
"""将 data[idx] 更新为 val"""
self._update(node=0, l=0, r=self.n - 1, idx=idx, val=val)
def _update(self, node: int, l: int, r: int, idx: int, val: int) -> None:
if l == r:
self.tree[node] = val
return
mid = (l + r) // 2
left = 2 * node + 1
right = 2 * node + 2
if idx <= mid:
self._update(left, l, mid, idx, val)
else:
self._update(right, mid + 1, r, idx, val)
self.tree[node] = self.tree[left] + self.tree[right]
def query(self, ql: int, qr: int) -> int:
"""查询 [ql, qr] 的区间和"""
return self._query(node=0, l=0, r=self.n - 1, ql=ql, qr=qr)
def _query(self, node: int, l: int, r: int, ql: int, qr: int) -> int:
if ql > r or qr < l: # 查询区间与当前节点区间无交集
return 0
if ql <= l and r <= qr: # 当前节点区间完全在查询区间内
return self.tree[node]
mid = (l + r) // 2
left_sum = self._query(2 * node + 1, l, mid, ql, qr)
right_sum = self._query(2 * node + 2, mid + 1, r, ql, qr)
return left_sum + right_sum
复杂度分析:
_build:每个节点访问一次,$O(N)$。_update:沿树高走一条路径,$O(\log N)$。_query:每层至多访问 4 个节点,共 $O(\log N)$ 层,总复杂度 $O(\log N)$。- 空间:$O(N)$($4N$ 数组)。
2. 延迟传播 (Lazy Propagation)
如果我们需要的是区间更新(例如「将 $[l, r]$ 内所有元素加上 $x$」),基础线段树需要对区间内每个叶子节点做一次更新,复杂度退化到 $O(N \log N)$。
延迟传播的核心思想:当更新指令覆盖了当前节点的整个区间时,我们不立即下沉,而是:更新当前节点的值,在 lazy 数组中记录「未来需要下沉的增量」,然后直接返回。只有在后续查询或更新需要深入该节点的子节点时,才将 lazy 值下推给左右孩子。
class SegmentTreeLazy:
"""带延迟传播的线段树,支持 O(log N) 区间更新"""
def __init__(self, data: List[int]):
self.n = len(data)
self.tree: List[int] = [0] * (4 * self.n)
self.lazy: List[int] = [0] * (4 * self.n) # 延迟标记
self._build(data, node=0, l=0, r=self.n - 1)
def _build(self, data: List[int], node: int, l: int, r: int) -> None:
if l == r:
self.tree[node] = data[l]
return
mid = (l + r) // 2
self._build(data, 2 * node + 1, l, mid)
self._build(data, 2 * node + 2, mid + 1, r)
self.tree[node] = self.tree[2 * node + 1] + self.tree[2 * node + 2]
def _push(self, node: int, l: int, r: int) -> None:
"""将当前节点的 lazy 值下推至左右孩子"""
if self.lazy[node] == 0:
return
mid = (l + r) // 2
left = 2 * node + 1
right = 2 * node + 2
# 将增量传播到孩子
self.tree[left] += self.lazy[node] * (mid - l + 1) # 左孩子区间长度 * 增量
self.lazy[left] += self.lazy[node]
self.tree[right] += self.lazy[node] * (r - mid) # 右孩子区间长度 * 增量
self.lazy[right] += self.lazy[node]
self.lazy[node] = 0 # 清除当前标记
def range_update(self, ql: int, qr: int, val: int) -> None:
"""将 [ql, qr] 内所有元素加上 val"""
self._range_update(node=0, l=0, r=self.n - 1, ql=ql, qr=qr, val=val)
def _range_update(self, node: int, l: int, r: int,
ql: int, qr: int, val: int) -> None:
if ql > r or qr < l: # 无交集
return
if ql <= l and r <= qr: # 完全覆盖
self.tree[node] += val * (r - l + 1) # 更新当前节点
self.lazy[node] += val # 标记延迟更新
return
# 部分覆盖:先下推,再递归更新子节点
self._push(node, l, r)
mid = (l + r) // 2
self._range_update(2 * node + 1, l, mid, ql, qr, val)
self._range_update(2 * node + 2, mid + 1, r, ql, qr, val)
self.tree[node] = self.tree[2 * node + 1] + self.tree[2 * node + 2]
def query(self, ql: int, qr: int) -> int:
"""查询带延迟标记的区间和"""
return self._query(node=0, l=0, r=self.n - 1, ql=ql, qr=qr)
def _query(self, node: int, l: int, r: int, ql: int, qr: int) -> int:
if ql > r or qr < l:
return 0
if ql <= l and r <= qr:
return self.tree[node]
self._push(node, l, r) # 必须先下推
mid = (l + r) // 2
return (self._query(2 * node + 1, l, mid, ql, qr) +
self._query(2 * node + 2, mid + 1, r, ql, qr))
复杂度分析:区间更新与区间查询的复杂度均为 $O(\log N)$——_push 操作只沿着树高的一小部分执行。
3. 坐标离散化
当数据范围极大(如 $[1, 10^9]$)但实际涉及的元素很少(如 $10^5$ 个)时,直接建树会内存溢出。此时需要坐标离散化:
def compress(nums: List[int]) -> List[int]:
"""坐标压缩:将数值映射到 [0, len(unique) - 1] 的连续区间"""
uniq = sorted(set(nums))
rank = {v: i for i, v in enumerate(uniq)}
return [rank[x] for x in nums]
离散化后,线段树的区间范围缩小为 $[0, \text{\#unique} - 1]$,内存需求显著降低。
4. LeetCode 实战应用
4.1 #307 区域和检索 - 数组可修改 (Range Sum Query - Mutable)
题目描述:设计数据结构,支持 update(index, val) 和 sumRange(left, right)。
这是线段树的教科书级应用——标准的可变区间求和问题:
class NumArray:
"""LeetCode 307: 线段树直接应用"""
def __init__(self, nums: List[int]):
self.st = SegmentTree(nums) # 复用上文线段树
def update(self, index: int, val: int) -> None:
self.st.update(index, val)
def sumRange(self, left: int, right: int) -> int:
return self.st.query(left, right)
对比:若使用树状数组(Fenwick Tree / BIT),代码更加简洁(常数更小),但功能限于前缀求和。线段树的优势在于可以灵活切换聚合函数(和、最大值、最小值、GCD 等),且支持区间更新。
4.2 #493 翻转对 (Reverse Pairs)
题目描述:给定数组 nums,找出所有 $i < j$ 且 $\text{nums[i]} > 2 \times \text{nums[j]}$ 的数对。
def reverse_pairs(nums: List[int]) -> int:
"""LeetCode 493: 线段树 + 离散化 + 逆序遍历"""
# Step 1: 坐标离散化(包含原值与 2*原值)
all_vals = sorted(set(list(nums) + [2 * x for x in nums]))
rank = {v: i for i, v in enumerate(all_vals)}
m = len(all_vals)
# Step 2: 线段树按秩统计出现次数
tree = [0] * (4 * m)
def update(node: int, l: int, r: int, idx: int) -> None:
if l == r:
tree[node] += 1
return
mid = (l + r) // 2
if idx <= mid:
update(2 * node + 1, l, mid, idx)
else:
update(2 * node + 2, mid + 1, r, idx)
tree[node] = tree[2 * node + 1] + tree[2 * node + 2]
def query(node: int, l: int, r: int, ql: int, qr: int) -> int:
if ql > r or qr < l:
return 0
if ql <= l and r <= qr:
return tree[node]
mid = (l + r) // 2
return (query(2 * node + 1, l, mid, ql, qr) +
query(2 * node + 2, mid + 1, r, ql, qr))
# Step 3: 从右向左遍历,统计满足 nums[i] > 2*nums[j] 的 j 的数量
ans = 0
for i in range(len(nums) - 1, -1, -1):
# 查询已处理的元素中有多少 < nums[i]/2
target = rank[nums[i]]
ans += query(0, 0, m - 1, 0, target - 1)
# 将 2*nums[i] 标记为已处理
update(0, 0, m - 1, rank[2 * nums[i]])
return ans
核心思想:从右向左遍历,线段树维护已遍历元素(2 × 值)的频次分布。对于当前的 nums[i],只需查询比它小的已遍历元素数量。由于「比它小」对应 nums[j] < nums[i] / 2 即 2 * nums[j] < nums[i],直接查询线段树 $[0, \text{rank[nums[i]]}-1]$ 的和即可。
4.3 #699 掉落的方块 (Falling Squares)
题目描述:正方形依次从空中落下,问每次落下后当前的最大高度。
这是线段树区间最值查询 + 区间更新(Lazy)的经典案例:
from typing import List
def falling_squares(positions: List[List[int]]) -> List[int]:
"""LeetCode 699: 线段树区间最大值 + 惰性更新"""
# Step 1: 坐标离散化
coords = set()
for left, side in positions:
coords.add(left)
coords.add(left + side - 1) # 区间右端点
uniq = sorted(coords)
rank = {v: i for i, v in enumerate(uniq)}
m = len(uniq)
# Step 2: 线段树(区间最大值 + 惰性覆盖)
tree = [0] * (4 * m)
lazy = [0] * (4 * m)
def push(node: int) -> None:
if lazy[node]:
tree[2 * node + 1] = max(tree[2 * node + 1], lazy[node])
lazy[2 * node + 1] = max(lazy[2 * node + 1], lazy[node])
tree[2 * node + 2] = max(tree[2 * node + 2], lazy[node])
lazy[2 * node + 2] = max(lazy[2 * node + 2], lazy[node])
lazy[node] = 0
def range_update(node: int, l: int, r: int, ql: int, qr: int, val: int) -> None:
if ql > r or qr < l:
return
if ql <= l and r <= qr:
tree[node] = max(tree[node], val)
lazy[node] = max(lazy[node], val)
return
push(node)
mid = (l + r) // 2
range_update(2 * node + 1, l, mid, ql, qr, val)
range_update(2 * node + 2, mid + 1, r, ql, qr, val)
tree[node] = max(tree[2 * node + 1], tree[2 * node + 2])
def query_max(node: int, l: int, r: int, ql: int, qr: int) -> int:
if ql > r or qr < l:
return 0
if ql <= l and r <= qr:
return tree[node]
push(node)
mid = (l + r) // 2
return max(query_max(2 * node + 1, l, mid, ql, qr),
query_max(2 * node + 2, mid + 1, r, ql, qr))
# Step 3: 逐个处理方块
result = []
for left, side in positions:
l = rank[left]
r = rank[left + side - 1]
# 查询当前区间最高高度
base_h = query_max(0, 0, m - 1, l, r)
# 叠加当前方块
range_update(0, 0, m - 1, l, r, base_h + side)
# 全局最高高度 = 整棵树的根节点值
result.append(tree[0])
return result
核心思想:将二维方块转化为一维区间,用线段树维护每个水平区间当前的最大高度。每次新方块落下时:先查询区间最高值作为基准,再将该区间全部更新为该基准 + 方块高度。惰性更新避免逐元素修改。
结论 (Conclusion)
线段树是一种将「分治思想」具象化为完全二叉树的数据结构。它的本质是预计算——提前在树的每个节点储存该段区间的聚合信息,从而将查询粒度从 $O(N)$ 压缩到 $O(\log N)$。配合延迟传播,线段树可以高效处理各类区间操作(加、赋值、取 max/min、甚至混合操作)。下一篇我们将进入另一种基于完全二叉树的经典结构——堆与优先队列。