Brief Introduction 线段树是一棵完全二叉树,它的每个节点代表一个区间,子节点的区间是父节点的区间的子区间。 如图,是一棵高度h=5的完全二叉树(即:叶节点只存在于第h和第h-1层)。每个节点保存了所涵盖的区间信息。 线段树可以用来快速查询区间信息,例如给定一个数组 arr = [10, 3, 23, 34, 2, 21, 3, 4, 9, 1];
查询任意区间[L, R], 0 <= L <= R < arr.len 的和;
修改任意一个arr[i];
如查询[2, 7]之间的和,只需sum([2, 2] + [3, 4] + [5, 7])三个节点之和即可。 如修改arr[2],只需修改节点[0, 9] [0, 4] [0, 2] [2, 2] 四个节点即可。 查询和修改的时间复杂度都为O(logN),N=len(arr),即arr的长度。而建树的时间复杂度为O(N).
Recursive methods for Segment Trees 二叉树一般用数组存储,且如果父节点的下标为i,则:左孩子的下标为2*i+1,右孩子下标为2*i+2。 完全二叉树的性质:有n个叶节点,就有n-1个非叶节点。 用数组tree[]保存线段树的节点,则给定一个len(arr)=n的数组,至少需要len(tree) >= 2*n-1构建线段树。如果n是2的幂,则len(tree)=2*n-1,否则len(tree) = 2 ∗ 2 ⌈ log 2 n ⌉ − 1 2 * 2^{\left\lceil\log _{2} n\right\rceil}-1 2 ∗ 2 ⌈ l o g 2 n ⌉ − 1 .例如:len(arr) = 4, 则线段树的结构为: (注:绿色为叶节点,index为该节点在tree[]中的位置) 其大小len(tree) = 2 * len(arr) - 1 = 7. 若len(arr) = 10, 则线段树的结构为: 其大小len(tree) = 2 ∗ 2 4 − 1 2 * 2^{4} - 1 2 ∗ 2 4 − 1 = 31. (⌈ log 2 l e n ( a r r ) ⌉ = 4 \left\lceil\log _{2} len(arr)\right\rceil = 4 ⌈ log 2 l e n ( a rr ) ⌉ = 4 ) 有12个无用叶节点,有2 * len(arr) - 1 = 19个有用节点(10个叶节点+9个非叶节点)。 由此可以看出,当len(arr)为2的幂时,线段树无无用节点,最后一个叶节点的index刚好等于2*len(arr)-1 - 1;当len(arr)不是2的幂时,树中存在无用节点,最后一个叶节点的index可能大于2*len(arr)-1(当len(arr)=10,最后一个叶节点index=24),除非无用节点只出现在最底层末端(例如len(arr)=3、5or7时)。无用节点必然成对儿存在。(因为有一个叶节点就有两个无用节点)
Build Segment Tree 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 class SegmentTree { int [] arr; int [] tree; public SegmentTree (int [] arr) { this .arr = arr; int n = arr.length; int h = 1 ; while (h < n) h <<= 1 ; this .tree = new int [2 * h - 1 ]; build(0 , 0 , n - 1 ); } private void build (int treeIndex, int range_left, int range_right) { if (range_left == range_right) { tree[treeIndex] = arr[range_left]; return ; } int mid = range_left + (range_right - range_left) / 2 ; build(2 * treeIndex + 1 , range_left, mid); build(2 * treeIndex + 2 , mid + 1 , range_right); tree[treeIndex] = merge(tree[2 * treeIndex + 1 ], tree[2 * treeIndex + 2 ]); } private int merge (int nodeA, int nodeB) { return nodeA + nodeB; } }
递归构建,range from 0 to n-1,即n个叶节点,同时有n-1个非叶节点,build()总共被调用2*n-1次,所以时间复杂度为O(logn).
Query 1 2 3 4 5 6 7 8 9 10 11 12 public int query (int treeIndex, int range_left, int range_right, int query_left, int query_right) { if (range_left > query_right || range_right < query_left) return 0 ; if (query_left <= range_left && range_right <= query_right) return tree[treeIndex]; int mid = range_left + (range_right - range_left) / 2 ; if (mid >= query_right) return query(2 * treeIndex + 1 , range_left, mid, query_left, query_right); else if (mid < query_left) return query(2 * treeIndex + 2 , mid + 1 , query_left, query_right); return merge(res_left, res_right); int res_left = query(2 * treeIndex + 1 , range_left, mid, query_left, query_right); int res_right = query(2 * treeIndex + 2 , mid + 1 , query_left, query_right); return merge(res_left, res_right); }
Update 1 2 3 4 5 6 7 8 9 10 11 public void update (int treeIndex, int range_left, int range_right, int arrIndex, int val) { if (range_left == range_right) { tree[treeIndex] = val; arr[arrIndex] = val; return ; } int mid = range_left + (range_right - range_left) / 2 ; if (mid >= arrIndex) update(2 * treeIndex + 1 , range_left, mid, arrIndex, val); else update(2 * treeIndex + 2 , mid + 1 , range_right, val); tree[treeIndex] = merge(tree[2 * treeIndex + 1 ], tree[2 * treeIndex + 2 ]); }
查询和更新的时间复杂度都为O(logn).
Lazy Propagation 延迟传播是线段树的延伸。 上面的update()只是更新某一个值,时间复杂度是O(logn)。如果要更新一个范围的值,效率自然就低了。因为:
可能对某个父节点更新多次;
频繁更新的某些节点并不在query范围内;
延迟传播就是解决这两个问题,一个节点只更新一次,不在query范围内的节点不会更新。它的时间复杂度仍是O(logn)。 延迟传播正如它的名字所言,它的更新操作是延迟的,即只有当该节点被access/query时再更新。如上述求区间和问题,我们增加一个数组lazy[],lazy[i]不为0表示该节点有增量。
Update 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 public void lazyUpdate (int treeIndex, int range_left, int range_right, int lo, int hi, int val) { if (range_left > hi || range_right < lo) return ; if (lazy[treeIndex] != 0 ) { tree[treeIndex] += (range_right - range_left + 1 ) * lazy[treeIndex]; if (range_left != range_right) { lazy[2 * treeIndex + 1 ] += lazy[treeIndex]; lazy[2 * treeIndex + 2 ] += lazy[treeIndex]; } lazy[treeIndex] = 0 ; } if (lo <= range_left && range_right <= hi) { tree[treeIndex] += val; if (range_left != range_right) { lazy[2 * treeIndex + 1 ] += val; lazy[2 * treeIndex + 2 ] += val; } return ; } int mid = range_left + (range_right - range_left) / 2 ; lazyUpdate(2 * treeIndex + 1 , range_left, mid, lo, hi, val); lazyUpdate(2 * treeIndex + 2 , mid, range_right, lo, hi, val); tree[treeIndex] = merge(tree[2 * treeIndex + 1 ], tree[2 * treeIndex + 2 ]); }
Query 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 public int lazyQuery (int treeIndex, int range_left, int range_right, int lo, int hi) { if (range_left > hi || range_right < lo) return ; if (lazy[treeIndex] != 0 ) { tree[treeIndex] += (range_right - range_left + 1 ) * lazy[treeIndex]; if (range_left != range_right) { lazy[2 * treeIndex + 1 ] += lazy[treeIndex]; lazy[2 * treeIndex + 2 ] += lazy[treeIndex]; } lazy[treeIndex] = 0 ; } if (lo <= range_left && range_right <= hi) return tree[treeIndex]; int mid = range_left + (range_right - range_left) / 2 ; if (mid >= hi) return lazyQuery(2 * treeIndex + 1 , range_left, mid, lo, hi); else if (mid < lo) return lazyQuery(2 * treeIndex + 2 , mid + 1 , range_right, lo, hi); return merge(lazyQuery(2 * treeIndex + 1 , range_left, mid, lo, hi), lazyQuery(2 * treeIndex + 2 , mid + 1 , range_right, lo, hi)); }
Practice https://leetcode.com/problems/falling-squares/ 题:给一组俄罗斯方块,都是正方形,每个方块用[i, len]表示,i表示该方块左下角坠落的x轴坐标,len表示边长。求每个方块坠落后,最高方块的y轴坐标。 例1: Input: [[1, 2], [2, 3], [6, 1]] Output: [2, 5, 5] 例2: Input: [[100, 100], [200, 100]] Output: [100, 100]
1 <= positions.length <= 1000. 1 <= positions[i][0] <= 10^8. 1 <= positions[i][1] <= 10^6.
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 class Solution { public List<Integer> fallingSquares (int [][] positions) { int range_left = Integer.MAX_VALUE, range_right = 0 ; for (int [] pos : positions) { range_left = Math.min(range_left, pos[0 ]); range_right = Math.max(range_right, pos[0 ] + pos[1 ]); } Node root = new Node (range_left, range_right - 1 ); List<Integer> ans = new ArrayList <>(positions.length); for (int [] pos : positions) { int h = search(root, pos); update(root, pos, h + pos[1 ]); ans.add(root.height); } return ans; } private int search (Node root, int [] pos) { if (root.range_right < pos[0 ] || root.range_left > pos[0 ] + pos[1 ] - 1 ) return 0 ; if (root.lazy_height != 0 ) { root.height = Math.max(root.height, root.lazy_height); if (root.range_left != root.range_right) { root.getLeft().lazy_height = Math.max(root.getLeft().lazy_height, root.lazy_height); root.getRight().lazy_height = Math.max(root.getRight().lazy_height, root.lazy_height); } root.lazy_height = 0 ; } if (pos[0 ] <= root.range_left && root.range_right <= pos[0 ] + pos[1 ] - 1 ) return root.height; return Math.max(search(root.getLeft(), pos), search(root.getRight(), pos)); } private void update (Node root, int [] pos, int h) { if (root.range_right < pos[0 ] || root.range_left > pos[0 ] + pos[1 ] - 1 ) return ; if (root.lazy_height != 0 ) { root.height = Math.max(root.height, root.lazy_height); if (root.range_left != root.range_right) { root.getLeft().lazy_height = Math.max(root.getLeft().lazy_height, root.lazy_height); root.getRight().lazy_height = Math.max(root.getRight().lazy_height, root.lazy_height); } root.lazy_height = 0 ; } root.height = Math.max(root.height, h); if (pos[0 ] <= root.range_left && root.range_right <= pos[0 ] + pos[1 ] - 1 ) { if (root.range_left != root.range_right) { root.getLeft().lazy_height = Math.max(root.getLeft().lazy_height, h); root.getRight().lazy_height = Math.max(root.getRight().lazy_height, h); } return ; } update(root.getLeft(), pos, h); update(root.getRight(), pos, h); } class Node { int range_left, range_right; int height; int lazy_height; Node left, right; public Node (int range_left, int range_right) { this .range_left = range_left; this .range_right = range_right; } public int getRangeMiddle () { return range_left + (range_right - range_left) / 2 ; } public Node getLeft () { if (left == null ) left = new Node (range_left, getRangeMiddle()); return left; } public Node getRight () { if (right == null ) right = new Node (getRangeMiddle() + 1 , range_right); return right; } } }
上面的写法使用了坐标的最大最小值,其范围可达 10^8,而数组长度仅1000,所以坐标最多2000个,使用坐标压缩可以将区间范围从10^8缩小到10^3.
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 class Solution { public List<Integer> fallingSquares (int [][] positions) { Set<Integer> coords = new HashSet <>(); for (int [] pos : positions) { coords.add(pos[0 ]); coords.add(pos[0 ] + pos[1 ] - 1 ); } List<Integer> sortedCoords = new ArrayList <>(coords); Collections.sort(sortedCoords); Map<Integer, Integer> index = new HashMap <>(); int n = 0 ; for (int i : sortedCoords) index.put(i, n++); SegmentTree stree = new SegmentTree (n); List<Integer> ans = new ArrayList <>(positions.length); for (int [] pos : positions) { int i = index.get(pos[0 ]); int j = index.get(pos[0 ] + pos[1 ] - 1 ); int h = stree.query(0 , 0 , n - 1 , i, j) + pos[1 ]; stree.update(0 , 0 , n - 1 , i, j, h); ans.add(stree.tree[0 ]); } return ans; } class SegmentTree { int [] tree; int [] lazy; public SegmentTree (int n) { int t = 1 ; while (t < n) t <<= 1 ; tree = new int [2 * t - 1 ]; lazy = new int [tree.length]; } public int query (int treeIndex, int lo, int hi, int i, int j) { if (lo > j || hi < i) return 0 ; if (lazy[treeIndex] != 0 ) { tree[treeIndex] = Math.max(tree[treeIndex], lazy[treeIndex]); if (lo != hi) { lazy[2 * treeIndex + 1 ] = lazy[treeIndex]; lazy[2 * treeIndex + 2 ] = lazy[treeIndex]; } lazy[treeIndex] = 0 ; } if (i <= lo && hi <= j) return tree[treeIndex]; int mid = lo + (hi - lo) / 2 ; return Math.max(query(2 * treeIndex + 1 , lo, mid, i, j), query(2 * treeIndex + 2 , mid + 1 , hi, i, j)); } public void update (int treeIndex, int lo, int hi, int i, int j, int height) { if (lo > j || hi < i) return ; if (lazy[treeIndex] != 0 ) { tree[treeIndex] = Math.max(tree[treeIndex], lazy[treeIndex]); if (lo != hi) { lazy[2 * treeIndex + 1 ] = Math.max(lazy[2 * treeIndex + 1 ], lazy[treeIndex]); lazy[2 * treeIndex + 2 ] = Math.max(lazy[2 * treeIndex + 2 ], lazy[treeIndex]); } lazy[treeIndex] = 0 ; } if (i <= lo && hi <= j) { tree[treeIndex] = Math.max(tree[treeIndex], height); if (lo != hi) { lazy[2 * treeIndex + 1 ] = Math.max(lazy[2 * treeIndex + 1 ], height); lazy[2 * treeIndex + 2 ] = Math.max(lazy[2 * treeIndex + 2 ], height); } return ; } int mid = lo + (hi - lo) / 2 ; update(2 * treeIndex + 1 , lo, mid, i, j, height); update(2 * treeIndex + 2 , mid + 1 , hi, i, j, height); tree[treeIndex] = Math.max(tree[2 * treeIndex + 1 ], tree[2 * treeIndex + 2 ]); } } }
2251. Number of Flowers in Full Bloom
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 class Solution : def fullBloomFlowers (self, flowers: List [List [int ]], persons: List [int ] ) -> List [int ]: points = set () for x, y in flowers: points.add(x) points.add(y) for x in persons: points.add(x) sorted_points = sorted (points) mp = {x: i for i, x in enumerate (sorted_points)} root = Node(0 , len (sorted_points)) for x, y in flowers: insert(root, mp[x], mp[y], 1 ) ans = [] for x in persons: i = mp[x] ans.append(query(root, i, i)) return ans class Node : __slots__ = 'l' , 'r' , 'v' , 'lazy' , 'left' , 'right' def __init__ (self, l: int , r: int , v: int = 0 , lazy: int = 0 ): self.l = l self.r = r self.v = v self.lazy = lazy self.left = None if l == r else Node(l, self._mid()) self.right = None if l == r else Node(self._mid() + 1 , r) def _mid (self ): return (self.l + self.r) >> 1 def insert (root: Node, l: int , r: int , v: int ): if root.r < l or r < root.l: return elif l <= root.l and root.r <= r: root.v += v if l != r: root.lazy += v else : insert(root.left, l, r, v) insert(root.right, l, r, v) root.v += root.left.v + root.right.v def query (root: Node, l: int , r: int ): if root.r < l or r < root.l: return 0 elif l <= root.l and root.r <= r: return root.v else : if root.lazy > 0 : root.left.v += root.lazy root.right.v += root.lazy if root.left.l != root.left.r: root.left.lazy += root.lazy if root.right.l != root.right.r: root.right.lazy += root.lazy root.lazy = 0 return query(root.left, l, r) + query(root.right, l, r)
Reference https://leetcode.com/articles/a-recursive-approach-to-segment-trees-range-sum-queries-lazy-propagation/