@@ -57,33 +57,34 @@ class TreeNode:
57
57
self .left = - 1 # 区间左边界
58
58
self .right = - 1 # 区间右边界
59
59
self .val = val # 节点值(区间值)
60
- self .lazy_tag = 0 # 区间和问题的延迟更新标记
60
+ self .lazy_tag = None # 区间和问题的延迟更新标记
61
+
61
62
62
63
# 线段树类
63
64
class SegmentTree :
64
65
def __init__ (self , nums , function ):
65
66
self .size = len (nums)
66
- self .tree = [TreeNode() for _ in range (4 * self .size)] # 维护 TreeNode 数组
67
+ self .tree = [TreeNode() for _ in range (4 * self .size)] # 维护 TreeNode 数组
67
68
self .nums = nums # 原始数据
68
69
self .function = function # function 是一个函数,左右区间的聚合方法
69
70
if self .size > 0 :
70
71
self .__build(0 , 0 , self .size - 1 )
71
-
72
+
72
73
# 构建线段树,节点的存储下标为 index,节点的区间为 [left, right]
73
74
def __build (self , index , left , right ):
74
75
self .tree[index].left = left
75
76
self .tree[index].right = right
76
77
if left == right: # 叶子节点,节点值为对应位置的元素值
77
78
self .tree[index].val = self .nums[left]
78
79
return
79
-
80
+
80
81
mid = left + (right - left) // 2 # 左右节点划分点
81
82
left_index = index * 2 + 1 # 左子节点的存储下标
82
83
right_index = index * 2 + 2 # 右子节点的存储下标
83
- self ._build (left_index, left, mid) # 递归创建左子树
84
- self ._build (right_index, mid + 1 , right) # 递归创建右子树
84
+ self .__build (left_index, left, mid) # 递归创建左子树
85
+ self .__build (right_index, mid + 1 , right) # 递归创建右子树
85
86
self .__pushup(index) # 向上更新节点的区间值
86
-
87
+
87
88
# 向上更新下标为 index 的节点区间值,节点的区间值等于该节点左右子节点元素值的聚合计算结果
88
89
def __pushup (self , index ):
89
90
left_index = index * 2 + 1 # 左子节点的存储下标
@@ -125,9 +126,9 @@ class SegmentTree:
125
126
mid = left + (right - left) // 2 # 左右节点划分点
126
127
left_index = index * 2 + 1 # 左子节点的存储下标
127
128
right_index = index * 2 + 2 # 右子节点的存储下标
128
- if i <= mid: # 在左子树中更新
129
+ if i <= mid: # 在左子树中更新节点值
129
130
self .__update_point(i, val, left_index, left, mid)
130
- else : # 在右子树中更新
131
+ else : # 在右子树中更新节点值
131
132
self .__update_point(i, val, right_index, mid + 1 , right)
132
133
self .__pushup(index) # 向上更新节点的区间值
133
134
```
@@ -158,7 +159,9 @@ class SegmentTree:
158
159
return self .tree[index].val # 直接返回节点值
159
160
if right < q_left or left > q_right: # 节点所在区间与 [q_left, q_right] 无关
160
161
return 0
161
-
162
+
163
+ self .__pushdown(index)
164
+
162
165
mid = left + (right - left) // 2 # 左右节点划分点
163
166
left_index = index * 2 + 1 # 左子节点的存储下标
164
167
right_index = index * 2 + 2 # 右子节点的存储下标
@@ -216,36 +219,42 @@ class SegmentTree:
216
219
def __update_interval (self , q_left , q_right , val , index , left , right ):
217
220
218
221
if left >= q_left and right <= q_right: # 节点所在区间被 [q_left, q_right] 所覆盖
222
+ interval_size = (right - left + 1 ) # 当前节点所在区间大小
223
+ self .tree[index].val = interval_size * val # 当前节点所在区间每个元素值改为 val
219
224
self .tree[index].lazy_tag = val # 将当前节点的延迟标记为区间值
220
- self .tree[index].val = val # 更新当前节点所在区间值
221
225
return
222
226
if right < q_left or left > q_right: # 节点所在区间与 [q_left, q_right] 无关
223
227
return 0
224
-
225
- if self .tree[index].lazy_tag: # 需要向下更新节点所在区间的左右子节点的值和懒惰标记
226
- self .__pushdown(index)
227
-
228
+
229
+ self .__pushdown(index)
230
+
228
231
mid = left + (right - left) // 2 # 左右节点划分点
229
232
left_index = index * 2 + 1 # 左子节点的存储下标
230
233
right_index = index * 2 + 2 # 右子节点的存储下标
231
234
if q_left <= mid: # 在左子树中更新区间值
232
- self ._update_interval (q_left, q_right, val, left_index, left, mid)
235
+ self .__update_interval (q_left, q_right, val, left_index, left, mid)
233
236
if q_right > mid: # 在右子树中更新区间值
234
- self ._update_interval(q_left, q_right, val, right_index, mid + 1 , right)
237
+ self .__update_interval(q_left, q_right, val, right_index, mid + 1 , right)
238
+
235
239
self .__pushup(index)
236
240
237
241
# 向下更新下标为 index 的节点所在区间的左右子节点的值和懒惰标记
238
242
def __pushdown (self , index ):
239
243
lazy_tag = self .tree[index].lazy_tag
244
+ if not lazy_tag:
245
+ return
246
+
240
247
left_index = index * 2 + 1 # 左子节点的存储下标
241
248
right_index = index * 2 + 2 # 右子节点的存储下标
242
-
243
- self .tree[left_index].lazy_tag = lazy_tag # 更新左子节点懒惰标记
244
- self .tree[left_index].val = lazy_tag # 更新左子节点值
245
-
246
- self .tree[right_index].lazy_tag = lazy_tag # 更新右子节点懒惰标记
247
- self .tree[right_index].val = lazy_tag # 更新右子节点值
248
249
250
+ self .tree[left_index].lazy_tag = lazy_tag # 更新左子节点懒惰标记
251
+ left_size = (self .tree[left_index].right - self .tree[left_index].left + 1 )
252
+ self .tree[left_index].val = lazy_tag * left_size # 更新左子节点值
253
+
254
+ self .tree[right_index].lazy_tag = lazy_tag # 更新右子节点懒惰标记
255
+ right_size = (self .tree[right_index].right - self .tree[right_index].left + 1 )
256
+ self .tree[right_index].val = lazy_tag * right_size # 更新右子节点值
257
+
249
258
self .tree[index].lazy_tag = None # 更新当前节点的懒惰标记
250
259
```
251
260
@@ -264,40 +273,56 @@ class SegmentTree:
264
273
def __update_interval (self , q_left , q_right , val , index , left , right ):
265
274
266
275
if left >= q_left and right <= q_right: # 节点所在区间被 [q_left, q_right] 所覆盖
267
- self .tree[index].lazy_tag += val # 将当前节点的延迟标记增加 val
276
+ # interval_size = (right - left + 1) # 当前节点所在区间大小
277
+ # self.tree[index].val = interval_size * val # 当前节点所在区间每个元素值改为 val
278
+ # self.tree[index].lazy_tag = val # 将当前节点的延迟标记为区间值
279
+
280
+ if self .tree[index].lazy_tag:
281
+ self .tree[index].lazy_tag += val # 将当前节点的延迟标记增加 val
282
+ else :
283
+ self .tree[index].lazy_tag = val # 将当前节点的延迟标记增加 val
268
284
interval_size = (right - left + 1 ) # 当前节点所在区间大小
269
- self .tree[index].val += val * interval_size # 当前节点所在区间每个元素值增加 val
285
+ self .tree[index].val += val * interval_size # 当前节点所在区间每个元素值增加 val
270
286
return
271
287
if right < q_left or left > q_right: # 节点所在区间与 [q_left, q_right] 无关
272
288
return 0
273
-
274
- if self .tree[index].lazy_tag: # 需要向下更新节点所在区间的左右子节点的值和懒惰标记
275
- self .__pushdown(index)
276
-
289
+
290
+ self .__pushdown(index)
291
+
277
292
mid = left + (right - left) // 2 # 左右节点划分点
278
293
left_index = index * 2 + 1 # 左子节点的存储下标
279
294
right_index = index * 2 + 2 # 右子节点的存储下标
280
295
if q_left <= mid: # 在左子树中更新区间值
281
- self ._update_interval (q_left, q_right, val, left_index, left, mid)
296
+ self .__update_interval (q_left, q_right, val, left_index, left, mid)
282
297
if q_right > mid: # 在右子树中更新区间值
283
- self ._update_interval(q_left, q_right, val, right_index, mid + 1 , right)
298
+ self .__update_interval(q_left, q_right, val, right_index, mid + 1 , right)
299
+
284
300
self .__pushup(index)
285
301
286
302
# 向下更新下标为 index 的节点所在区间的左右子节点的值和懒惰标记
287
303
def __pushdown (self , index ):
288
304
lazy_tag = self .tree[index].lazy_tag
305
+ if not lazy_tag:
306
+ return
307
+
289
308
left_index = index * 2 + 1 # 左子节点的存储下标
290
309
right_index = index * 2 + 2 # 右子节点的存储下标
291
310
292
- self .tree[left_index].lazy_tag += lazy_tag # 更新左子节点懒惰标记
311
+ if self .tree[left_index].lazy_tag:
312
+ self .tree[left_index].lazy_tag += lazy_tag # 更新左子节点懒惰标记
313
+ else :
314
+ self .tree[left_index].lazy_tag = lazy_tag
293
315
left_size = (self .tree[left_index].right - self .tree[left_index].left + 1 )
294
316
self .tree[left_index].val += lazy_tag * left_size # 左子节点每个元素值增加 lazy_tag
295
317
296
- self .tree[right_index].lazy_tag += lazy_tag # 更新右子节点懒惰标记
318
+ if self .tree[right_index].lazy_tag:
319
+ self .tree[right_index].lazy_tag += lazy_tag # 更新右子节点懒惰标记
320
+ else :
321
+ self .tree[right_index].lazy_tag = lazy_tag
297
322
right_size = (self .tree[right_index].right - self .tree[right_index].left + 1 )
298
323
self .tree[right_index].val += lazy_tag * right_size # 右子节点每个元素值增加 lazy_tag
299
324
300
- self .tree[index].lazy_tag = None # 更新当前节点的懒惰标记
325
+ self .tree[index].lazy_tag = None # 更新当前节点的懒惰标记
301
326
```
302
327
303
328
## 4. 线段树的常见题型
@@ -343,6 +368,152 @@ class SegmentTree:
343
368
344
369
这类问题通常坐标跨度很大,需要先对每条扫描线的坐标进行离散化处理,将 ` y ` 坐标映射到 ` 0, 1, 2, ... ` 中。然后将每条竖线的端点作为区间范围,使用线段树存储每条竖线的信息(` x ` 坐标、是左竖线还是右竖线等),然后再进行区间合并,并统计相关信息。
345
370
371
+ ## 5. 线段树的拓展
372
+
373
+ ### 5.1 动态开点线段树
374
+
375
+ 在有些情况下,线段树需要维护的区间很大(例如 $[ 1, 10^9] $),在实际中用到的节点却很少。
376
+
377
+ 如果使用之前数组形式实现线段树,则需要 $4 * n$ 大小的空间,空间消耗有点过大了。
378
+
379
+ 这时候我们就可以使用动态开点的思想来构建线段树。
380
+
381
+ 动态开点线段树的算法思想如下:
382
+
383
+ - 开始时只建立一个根节点,代表整个区间。
384
+ - 当需要访问线段树的某棵子树(某个子区间)时,再建立代表这个子区间的节点。
385
+
386
+ 动态开点线段树实现代码如下:
387
+
388
+ ``` Python
389
+ # 线段树的节点类
390
+ class TreeNode :
391
+ def __init__ (self , left = - 1 , right = - 1 , val = 0 ):
392
+ self .left = left # 区间左边界
393
+ self .right = right # 区间右边界
394
+ self .mid = left + (right - left) // 2
395
+ self .leftNode = None # 区间左节点
396
+ self .rightNode = None # 区间右节点
397
+ self .val = val # 节点值(区间值)
398
+ self .lazy_tag = None # 区间问题的延迟更新标记
399
+
400
+
401
+ # 线段树类
402
+ class SegmentTree :
403
+ def __init__ (self , function ):
404
+ self .tree = TreeNode(0 , int (1e9 ))
405
+ self .function = function # function 是一个函数,左右区间的聚合方法
406
+
407
+ # 向上更新 node 节点区间值,节点的区间值等于该节点左右子节点元素值的聚合计算结果
408
+ def __pushup (self , node ):
409
+ leftNode = node.leftNode
410
+ rightNode = node.rightNode
411
+ if leftNode and rightNode:
412
+ node.val = self .function(leftNode.val, rightNode.val)
413
+
414
+ # 单点更新,将 nums[i] 更改为 val
415
+ def update_point (self , i , val ):
416
+ self .__update_point(i, val, self .tree)
417
+
418
+ # 单点更新,将 nums[i] 更改为 val。node 节点的区间为 [node.left, node.right]
419
+ def __update_point (self , i , val , node ):
420
+ if node.left == node.right:
421
+ node.val = val # 叶子节点,节点值修改为 val
422
+ return
423
+
424
+ if i <= node.mid: # 在左子树中更新节点值
425
+ if not node.leftNode:
426
+ node.leftNode = TreeNode(node.left, node.mid)
427
+ self .__update_point(i, val, node.leftNode)
428
+ else : # 在右子树中更新节点值
429
+ if not node.rightNode:
430
+ node.rightNode = TreeNode(node.mid + 1 , node.right)
431
+ self .__update_point(i, val, node.rightNode)
432
+ self .__pushup(node) # 向上更新节点的区间值
433
+
434
+ # 区间查询,查询区间为 [q_left, q_right] 的区间值
435
+ def query_interval (self , q_left , q_right ):
436
+ return self .__query_interval(q_left, q_right, self .tree)
437
+
438
+ # 区间查询,在线段树的 [left, right] 区间范围中搜索区间为 [q_left, q_right] 的区间值
439
+ def __query_interval (self , q_left , q_right , node ):
440
+ if node.left >= q_left and node.right <= q_right: # 节点所在区间被 [q_left, q_right] 所覆盖
441
+ return node.val # 直接返回节点值
442
+ if node.right < q_left or node.left > q_right: # 节点所在区间与 [q_left, q_right] 无关
443
+ return 0
444
+
445
+ self .__pushdown(node) # 向下更新节点所在区间的左右子节点的值和懒惰标记
446
+
447
+ res_left = 0 # 左子树查询结果
448
+ res_right = 0 # 右子树查询结果
449
+ if q_left <= node.mid: # 在左子树中查询
450
+ if not node.leftNode:
451
+ node.leftNode = TreeNode(node.left, node.mid)
452
+ res_left = self .__query_interval(q_left, q_right, node.leftNode)
453
+ if q_right > node.mid: # 在右子树中查询
454
+ if not node.rightNode:
455
+ node.rightNode = TreeNode(node.mid + 1 , node.right)
456
+ res_right = self .__query_interval(q_left, q_right, node.rightNode)
457
+ return self .function(res_left, res_right) # 返回左右子树元素值的聚合计算结果
458
+
459
+ # 区间更新,将区间为 [q_left, q_right] 上的元素值修改为 val
460
+ def update_interval (self , q_left , q_right , val ):
461
+ self .__update_interval(q_left, q_right, val, self .tree)
462
+
463
+ # 区间更新
464
+ def __update_interval (self , q_left , q_right , val , node ):
465
+ if node.left >= q_left and node.right <= q_right: # 节点所在区间被 [q_left, q_right] 所覆盖
466
+ if node.lazy_tag:
467
+ node.lazy_tag += val # 将当前节点的延迟标记增加 val
468
+ else :
469
+ node.lazy_tag = val # 将当前节点的延迟标记增加 val
470
+ interval_size = (node.right - node.left + 1 ) # 当前节点所在区间大小
471
+ node.val += val * interval_size # 当前节点所在区间每个元素值增加 val
472
+ return
473
+ if node.right < q_left or node.left > q_right: # 节点所在区间与 [q_left, q_right] 无关
474
+ return 0
475
+
476
+ self .__pushdown(node) # 向下更新节点所在区间的左右子节点的值和懒惰标记
477
+
478
+ if q_left <= node.mid: # 在左子树中更新区间值
479
+ if not node.leftNode:
480
+ node.leftNode = TreeNode(node.left, node.mid)
481
+ self .__update_interval(q_left, q_right, val, node.leftNode)
482
+ if q_right > node.mid: # 在右子树中更新区间值
483
+ if not node.rightNode:
484
+ node.rightNode = TreeNode(node.mid + 1 , node.right)
485
+ self .__update_interval(q_left, q_right, val, node.rightNode)
486
+
487
+ self .__pushup(node)
488
+
489
+ # 向下更新 node 节点所在区间的左右子节点的值和懒惰标记
490
+ def __pushdown (self , node ):
491
+ lazy_tag = node.lazy_tag
492
+ if not node.lazy_tag:
493
+ return
494
+
495
+ if not node.leftNode:
496
+ node.leftNode = TreeNode(node.left, node.mid)
497
+ if not node.rightNode:
498
+ node.rightNode = TreeNode(node.mid + 1 , node.right)
499
+
500
+ if node.leftNode.lazy_tag:
501
+ node.leftNode.lazy_tag += lazy_tag # 更新左子节点懒惰标记
502
+ else :
503
+ node.leftNode.lazy_tag = lazy_tag # 更新左子节点懒惰标记
504
+ left_size = (node.leftNode.right - node.leftNode.left + 1 )
505
+ node.leftNode.val += lazy_tag * left_size # 左子节点每个元素值增加 lazy_tag
506
+
507
+ if node.rightNode.lazy_tag:
508
+ node.rightNode.lazy_tag += lazy_tag # 更新右子节点懒惰标记
509
+ else :
510
+ node.rightNode.lazy_tag = lazy_tag # 更新右子节点懒惰标记
511
+ right_size = (node.rightNode.right - node.rightNode.left + 1 )
512
+ node.rightNode.val += lazy_tag * right_size # 右子节点每个元素值增加 lazy_tag
513
+
514
+ node.lazy_tag = None # 更新当前节点的懒惰标记
515
+ ```
516
+
346
517
## 参考资料
347
518
348
519
- 【书籍】ACM-ICPC 程序设计系列 - 算法设计与实现 - 陈宇 吴昊 主编
0 commit comments