Brief Introduction 线段树是一棵二叉树(每个非叶节点恰有两个子节点),它的每个节点代表一个区间,两个子节点的区间恰好把父节点的区间分成左右两半。给定数组 arr,线段树可以高效地支持以下操作:
查询任意区间[L, R], 0 <= L <= R < arr.len 的和; 
修改任意一个arr[i]; 
 
例如,查询区间 [2, 7] 的和时,只需把 [2, 2]、[3, 4]、[5, 7] 这三个节点的值相加即可,而无需逐个累加区间内的 6 个元素。
Recursive methods for Segment Trees 二叉树一般用数组存储:若父节点的下标为 i,则左孩子的下标为 2*i+1,右孩子的下标为 2*i+2。对长度为 n 的数组,线段树所需的数组大小为 $2 \cdot 2^{\lceil \log_2 n \rceil} - 1$。例如当 $\lceil \log_2 \mathrm{len}(arr) \rceil = 4$ 时,数组大小为 $2 \cdot 2^{4} - 1 = 31$。
// Build Segment Tree
/**
 * Array-backed segment tree over {@code arr}. Node {@code i} has children {@code 2*i+1} and
 * {@code 2*i+2}; the root (index 0) covers {@code arr[0..n-1]} and every node stores the
 * {@code merge} (here: the sum) of the values in its range.
 */
class SegmentTree {
    int[] arr;
    int[] tree;

    /**
     * Builds the tree over {@code arr}. The array is kept by reference, not copied.
     *
     * @param arr the values to index; may be empty
     */
    public SegmentTree(int[] arr) {
        this.arr = arr;
        int n = arr.length;
        // Round n up to a power of two h; a tree with h leaf slots needs at most 2*h - 1 nodes.
        int h = 1;
        while (h < n) {
            h <<= 1;
        }
        this.tree = new int[2 * h - 1];
        // With no elements there is nothing to build; build(0, 0, -1) would index past tree[0].
        if (n > 0) {
            build(0, 0, n - 1);
        }
    }

    /** Fills tree[treeIndex], which covers arr[range_left..range_right], and its whole subtree. */
    private void build(int treeIndex, int range_left, int range_right) {
        if (range_left == range_right) {
            tree[treeIndex] = arr[range_left];
            return;
        }
        int mid = range_left + (range_right - range_left) / 2;
        build(2 * treeIndex + 1, range_left, mid);
        build(2 * treeIndex + 2, mid + 1, range_right);
        tree[treeIndex] = merge(tree[2 * treeIndex + 1], tree[2 * treeIndex + 2]);
    }

    /** Combines two child values; summing makes the tree answer range-sum queries. */
    private int merge(int nodeA, int nodeB) {
        return nodeA + nodeB;
    }
}
递归构建:根节点的区间为 [0, n-1],共有 n 个叶节点和 n-1 个非叶节点,因此 build() 总共被调用 2n-1 次,每次调用的工作量为 O(1),所以构建的时间复杂度为 O(n)。
Query 1 2 3 4 5 6 7 8 9 10 11 12 public  int  query (int  treeIndex, int  range_left, int  range_right, int  query_left, int  query_right)  {    if (range_left > query_right || range_right < query_left) return  0 ;     if (query_left <= range_left && range_right <= query_right) return  tree[treeIndex];     int  mid  =  range_left + (range_right - range_left) / 2 ;     if (mid >= query_right) return  query(2  * treeIndex + 1 , range_left, mid, query_left, query_right);     else  if (mid < query_left) return  query(2  * treeIndex + 2 , mid + 1 , query_left, query_right);     return  merge(res_left, res_right);     int  res_left  =  query(2  * treeIndex + 1 , range_left, mid, query_left, query_right);     int  res_right  =  query(2  * treeIndex + 2 , mid + 1 , query_left, query_right);     return  merge(res_left, res_right); } 
Update 1 2 3 4 5 6 7 8 9 10 11 public  void  update (int  treeIndex, int  range_left, int  range_right, int  arrIndex, int  val)  {    if (range_left == range_right) {         tree[treeIndex] = val;         arr[arrIndex] = val;         return ;     }     int  mid  =  range_left + (range_right - range_left) / 2 ;     if (mid >= arrIndex) update(2  * treeIndex + 1 , range_left, mid, arrIndex, val);     else  update(2  * treeIndex + 2 , mid + 1 , range_right, val);     tree[treeIndex] = merge(tree[2  * treeIndex + 1 ], tree[2  * treeIndex + 2 ]); } 
查询和更新的时间复杂度都为O(logn).
Lazy Propagation 延迟传播是线段树的延伸,用于高效地执行区间更新(例如给 arr[lo..hi] 中的每个元素都加上 val)。如果把区间更新直接递归应用到每个受影响的节点,会有两个问题:
可能对某个父节点更新多次; 
频繁更新的某些节点并不在query范围内; 
 
延迟传播正是为了解决这两个问题:当某个节点的区间被更新范围完全覆盖时,只更新这个节点本身,并把对其子节点的更新暂存在 lazy 数组中,等到之后真正访问这些子节点时才应用;不在 query 范围内的节点因此不会被更新。区间更新和区间查询的时间复杂度仍是 O(log n)。
Update 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 public  void  lazyUpdate (int  treeIndex, int  range_left, int  range_right, int  lo, int  hi, int  val)  {    if (range_left > hi || range_right < lo) return ;     if (lazy[treeIndex] != 0 ) {         tree[treeIndex] += (range_right - range_left + 1 ) * lazy[treeIndex];         if (range_left != range_right) {             lazy[2  * treeIndex + 1 ] += lazy[treeIndex];             lazy[2  * treeIndex + 2 ] += lazy[treeIndex];         }         lazy[treeIndex] = 0 ;     }     if (lo <= range_left && range_right <= hi) {         tree[treeIndex] += val;         if (range_left != range_right) {             lazy[2  * treeIndex + 1 ] += val;             lazy[2  * treeIndex + 2 ] += val;         }         return ;     }     int  mid  =  range_left + (range_right - range_left) / 2 ;     lazyUpdate(2  * treeIndex + 1 , range_left, mid, lo, hi, val);     lazyUpdate(2  * treeIndex + 2 , mid, range_right, lo, hi, val);     tree[treeIndex] = merge(tree[2  * treeIndex + 1 ], tree[2  * treeIndex + 2 ]); } 
Query 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 public  int  lazyQuery (int  treeIndex, int  range_left, int  range_right, int  lo, int  hi)  {    if (range_left > hi || range_right < lo) return ;     if (lazy[treeIndex] != 0 ) {         tree[treeIndex] += (range_right - range_left + 1 ) * lazy[treeIndex];         if (range_left != range_right) {             lazy[2  * treeIndex + 1 ] += lazy[treeIndex];             lazy[2  * treeIndex + 2 ] += lazy[treeIndex];         }         lazy[treeIndex] = 0 ;     }     if (lo <= range_left && range_right <= hi) return  tree[treeIndex];     int  mid  =  range_left + (range_right - range_left) / 2 ;     if (mid >= hi) return  lazyQuery(2  * treeIndex + 1 , range_left, mid, lo, hi);     else  if (mid < lo) return  lazyQuery(2  * treeIndex + 2 , mid + 1 , range_right, lo, hi);     return  merge(lazyQuery(2  * treeIndex + 1 , range_left, mid, lo, hi), lazyQuery(2  * treeIndex + 2 , mid + 1 , range_right, lo, hi)); } 
Practice https://leetcode.com/problems/falling-squares/ 
题目约束:1 <= positions.length <= 1000。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 class  Solution  {    public  List<Integer> fallingSquares (int [][] positions)  {         int  range_left  =  Integer.MAX_VALUE, range_right = 0 ;         for (int [] pos : positions) {             range_left = Math.min(range_left, pos[0 ]);             range_right = Math.max(range_right, pos[0 ] + pos[1 ]);         }         Node  root  =  new  Node (range_left, range_right - 1 );         List<Integer> ans = new  ArrayList <>(positions.length);         for (int [] pos : positions) {             int  h  =  search(root, pos);             update(root, pos, h + pos[1 ]);             ans.add(root.height);         }         return  ans;     }          private  int  search (Node root, int [] pos)  {         if (root.range_right < pos[0 ] || root.range_left > pos[0 ] + pos[1 ] - 1 ) return  0 ;         if (root.lazy_height != 0 ) {             root.height = Math.max(root.height, root.lazy_height);             if (root.range_left != root.range_right) {                 root.getLeft().lazy_height = Math.max(root.getLeft().lazy_height, root.lazy_height);                 root.getRight().lazy_height = Math.max(root.getRight().lazy_height, root.lazy_height);             }             root.lazy_height = 0 ;         }         if (pos[0 ] <= root.range_left && root.range_right <= pos[0 ] + pos[1 ] - 1 ) return  root.height;         return  Math.max(search(root.getLeft(), pos), search(root.getRight(), pos));     }          private  void  update (Node root, int [] pos, int  h)  {         if (root.range_right < pos[0 ] || root.range_left > pos[0 ] + pos[1 ] - 1 ) return ;         if (root.lazy_height != 0 ) {             root.height = Math.max(root.height, root.lazy_height);             if (root.range_left != root.range_right) {                 
root.getLeft().lazy_height = Math.max(root.getLeft().lazy_height, root.lazy_height);                 root.getRight().lazy_height = Math.max(root.getRight().lazy_height, root.lazy_height);             }             root.lazy_height = 0 ;         }               root.height = Math.max(root.height, h);         if (pos[0 ] <= root.range_left && root.range_right <= pos[0 ] + pos[1 ] - 1 ) {             if (root.range_left != root.range_right) {                 root.getLeft().lazy_height = Math.max(root.getLeft().lazy_height, h);                 root.getRight().lazy_height = Math.max(root.getRight().lazy_height, h);             }             return ;         }         update(root.getLeft(), pos, h);         update(root.getRight(), pos, h);     }          class  Node  {         int  range_left, range_right;         int  height;         int  lazy_height;         Node left, right;                  public  Node (int  range_left, int  range_right)  {             this .range_left = range_left;             this .range_right = range_right;         }                  public  int  getRangeMiddle ()  {             return  range_left + (range_right - range_left) / 2 ;         }                  public  Node getLeft ()  {             if (left == null ) left = new  Node (range_left, getRangeMiddle());             return  left;         }                  public  Node getRight ()  {             if (right == null ) right = new  Node (getRangeMiddle() + 1 , range_right);             return  right;         }     } } 
上面的写法直接以坐标的最小值和最大值作为区间范围,该范围可达 10^8;而 positions 的长度至多为 1000,每个方块只贡献左、右两个端点,所以不同的坐标最多只有 2000 个。使用坐标压缩可以把区间范围从 10^8 缩小到至多 2000(即 10^3 量级)。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 class  Solution  {    public  List<Integer> fallingSquares (int [][] positions)  {         Set<Integer> coords = new  HashSet <>();         for (int [] pos : positions) {             coords.add(pos[0 ]);             coords.add(pos[0 ] + pos[1 ] - 1 );         }         List<Integer> sortedCoords = new  ArrayList <>(coords);         Collections.sort(sortedCoords);         Map<Integer, Integer> index = new  HashMap <>();         int  n  =  0 ;         for (int  i : sortedCoords) index.put(i, n++);         SegmentTree  stree  =  new  SegmentTree (n);         List<Integer> ans = new  ArrayList <>(positions.length);         for (int [] pos : positions) {             int  i  =  index.get(pos[0 ]);             int  j  =  index.get(pos[0 ] + pos[1 ] - 1 );             int  h  =  stree.query(0 , 0 , n - 1 , i, j) + pos[1 ];             stree.update(0 , 0 , n - 1 , i, j, h);             ans.add(stree.tree[0 ]);         }         return  ans;     }          class  SegmentTree  {         int [] tree;         int [] lazy;                  public  SegmentTree (int  n)  {             int  t  =  1 ;             while (t < n) t <<= 1 ;             tree = new  int [2  * t - 1 ];             lazy = new  int [tree.length];         }                  public  int  query (int  treeIndex, int  lo, int  hi, int  i, int  j)  {             if (lo > j || hi < i) return  0 ;             if (lazy[treeIndex] != 0 ) {                 tree[treeIndex] = Math.max(tree[treeIndex], lazy[treeIndex]);                 if (lo != hi) {                     lazy[2  * treeIndex + 1 ] = lazy[treeIndex];                     lazy[2  * treeIndex + 2 ] = lazy[treeIndex];                 }                 lazy[treeIndex] = 0 ;             }             if (i <= lo && hi <= j) return  tree[treeIndex];    
         int  mid  =  lo + (hi - lo) / 2 ;             return  Math.max(query(2  * treeIndex + 1 , lo, mid, i, j), query(2  * treeIndex + 2 , mid + 1 , hi, i, j));         }         public  void  update (int  treeIndex, int  lo, int  hi, int  i, int  j, int  height)  {             if (lo > j || hi < i) return ;             if (lazy[treeIndex] != 0 ) {                 tree[treeIndex] = Math.max(tree[treeIndex], lazy[treeIndex]);                 if (lo != hi) {                     lazy[2  * treeIndex + 1 ] = Math.max(lazy[2  * treeIndex + 1 ], lazy[treeIndex]);                     lazy[2  * treeIndex + 2 ] = Math.max(lazy[2  * treeIndex + 2 ], lazy[treeIndex]);                 }                 lazy[treeIndex] = 0 ;             }             if (i <= lo && hi <= j) {                 tree[treeIndex] = Math.max(tree[treeIndex], height);                 if (lo != hi) {                     lazy[2  * treeIndex + 1 ] = Math.max(lazy[2  * treeIndex + 1 ], height);                     lazy[2  * treeIndex + 2 ] = Math.max(lazy[2  * treeIndex + 2 ], height);                 }                 return ;             }             int  mid  =  lo + (hi - lo) / 2 ;             update(2  * treeIndex + 1 , lo, mid, i, j, height);             update(2  * treeIndex + 2 , mid + 1 , hi, i, j, height);             tree[treeIndex] = Math.max(tree[2  * treeIndex + 1 ], tree[2  * treeIndex + 2 ]);         }     } } 
2251. Number of Flowers in Full Bloom 
class Solution:
    def fullBloomFlowers(self, flowers: List[List[int]], persons: List[int]) -> List[int]:
        """Return, for each arrival time in ``persons``, how many flowers are in bloom.

        Each flower ``[start, end]`` blooms on the inclusive interval. All relevant times
        are coordinate-compressed, each bloom interval adds +1 over its compressed range,
        and each person is answered with a single-point query.
        """
        points = set()
        for start, end in flowers:
            points.add(start)
            points.add(end)
        points.update(persons)
        if not points:
            return []
        sorted_points = sorted(points)
        index = {x: i for i, x in enumerate(sorted_points)}
        # Compressed indices are 0..len-1, so the root covers exactly that range.
        root = Node(0, len(sorted_points) - 1)
        for start, end in flowers:
            insert(root, index[start], index[end], 1)
        return [query(root, index[t], index[t]) for t in persons]


class Node:
    """Sum segment tree node over the inclusive index range [l, r]; children are built eagerly."""

    __slots__ = 'l', 'r', 'v', 'lazy', 'left', 'right'

    def __init__(self, l: int, r: int, v: int = 0, lazy: int = 0):
        self.l = l
        self.r = r
        self.v = v  # sum of the values in [l, r], including adds still pending below
        self.lazy = lazy  # per-element add not yet pushed to the children
        self.left = None if l == r else Node(l, self._mid())
        self.right = None if l == r else Node(self._mid() + 1, r)

    def _mid(self) -> int:
        return (self.l + self.r) >> 1


def _push(root: Node) -> None:
    """Apply ``root``'s pending per-element add to both children, then clear it."""
    if root.lazy and root.left is not None:
        for child in (root.left, root.right):
            child.v += root.lazy * (child.r - child.l + 1)
            if child.left is not None:
                child.lazy += root.lazy
        root.lazy = 0


def insert(root: Node, l: int, r: int, v: int):
    """Add ``v`` to every position in the inclusive range [l, r]."""
    if root.r < l or r < root.l:
        return
    if l <= root.l and root.r <= r:
        # Every position in this node grows by v, so the node's sum grows by v * length.
        root.v += v * (root.r - root.l + 1)
        if root.left is not None:
            root.lazy += v
        return
    # Partial overlap (only possible on an internal node): push first so children are current.
    _push(root)
    insert(root.left, l, r, v)
    insert(root.right, l, r, v)
    root.v = root.left.v + root.right.v


def query(root: Node, l: int, r: int) -> int:
    """Return the sum of the values in the inclusive range [l, r]."""
    if root.r < l or r < root.l:
        return 0
    if l <= root.l and root.r <= r:
        return root.v
    _push(root)
    return query(root.left, l, r) + query(root.right, l, r)
Reference https://leetcode.com/articles/a-recursive-approach-to-segment-trees-range-sum-queries-lazy-propagation/