1
+ # 线段树的节点类
2
+ class TreeNode :
3
+ def __init__ (self , left = - 1 , right = - 1 , val = 0 ):
4
+ self .left = left # 区间左边界
5
+ self .right = right # 区间右边界
6
+ self .mid = left + (right - left ) // 2
7
+ self .leftNode = None # 区间左节点
8
+ self .rightNode = None # 区间右节点
9
+ self .val = val # 节点值(区间值)
10
+ self .lazy_tag = None # 区间问题的延迟更新标记
11
+
12
+
13
+ # 线段树类
14
+ class SegmentTree :
15
+ def __init__ (self , function ):
16
+ self .tree = TreeNode (0 , int (1e9 ))
17
+ self .function = function # function 是一个函数,左右区间的聚合方法
18
+
19
+ # 向上更新 node 节点区间值,节点的区间值等于该节点左右子节点元素值的聚合计算结果
20
+ def __pushup (self , node ):
21
+ leftNode = node .leftNode
22
+ rightNode = node .rightNode
23
+ if leftNode and rightNode :
24
+ node .val = self .function (leftNode .val , rightNode .val )
25
+
26
+ # 单点更新,将 nums[i] 更改为 val
27
+ def update_point (self , i , val ):
28
+ self .__update_point (i , val , self .tree )
29
+
30
+ # 单点更新,将 nums[i] 更改为 val。node 节点的区间为 [node.left, node.right]
31
+ def __update_point (self , i , val , node ):
32
+ if node .left == node .right :
33
+ node .val = val # 叶子节点,节点值修改为 val
34
+ return
35
+
36
+ if i <= node .mid : # 在左子树中更新节点值
37
+ if not node .leftNode :
38
+ node .leftNode = TreeNode (node .left , node .mid )
39
+ self .__update_point (i , val , node .leftNode )
40
+ else : # 在右子树中更新节点值
41
+ if not node .rightNode :
42
+ node .rightNode = TreeNode (node .mid + 1 , node .right )
43
+ self .__update_point (i , val , node .rightNode )
44
+ self .__pushup (node ) # 向上更新节点的区间值
45
+
46
+ # 区间查询,查询区间为 [q_left, q_right] 的区间值
47
+ def query_interval (self , q_left , q_right ):
48
+ return self .__query_interval (q_left , q_right , self .tree )
49
+
50
+ # 区间查询,在线段树的 [left, right] 区间范围中搜索区间为 [q_left, q_right] 的区间值
51
+ def __query_interval (self , q_left , q_right , node ):
52
+ if node .left >= q_left and node .right <= q_right : # 节点所在区间被 [q_left, q_right] 所覆盖
53
+ return node .val # 直接返回节点值
54
+ if node .right < q_left or node .left > q_right : # 节点所在区间与 [q_left, q_right] 无关
55
+ return 0
56
+
57
+ self .__pushdown (node ) # 向下更新节点所在区间的左右子节点的值和懒惰标记
58
+
59
+ res_left = 0 # 左子树查询结果
60
+ res_right = 0 # 右子树查询结果
61
+ if q_left <= node .mid : # 在左子树中查询
62
+ if not node .leftNode :
63
+ node .leftNode = TreeNode (node .left , node .mid )
64
+ res_left = self .__query_interval (q_left , q_right , node .leftNode )
65
+ if q_right > node .mid : # 在右子树中查询
66
+ if not node .rightNode :
67
+ node .rightNode = TreeNode (node .mid + 1 , node .right )
68
+ res_right = self .__query_interval (q_left , q_right , node .rightNode )
69
+ return self .function (res_left , res_right ) # 返回左右子树元素值的聚合计算结果
70
+
71
+ # 区间更新,将区间为 [q_left, q_right] 上的元素值修改为 val
72
+ def update_interval (self , q_left , q_right , val ):
73
+ self .__update_interval (q_left , q_right , val , self .tree )
74
+
75
+ # 区间更新
76
+ def __update_interval (self , q_left , q_right , val , node ):
77
+ if node .left >= q_left and node .right <= q_right : # 节点所在区间被 [q_left, q_right] 所覆盖
78
+ if node .lazy_tag :
79
+ node .lazy_tag += val # 将当前节点的延迟标记增加 val
80
+ else :
81
+ node .lazy_tag = val # 将当前节点的延迟标记增加 val
82
+ interval_size = (node .right - node .left + 1 ) # 当前节点所在区间大小
83
+ node .val += val * interval_size # 当前节点所在区间每个元素值增加 val
84
+ return
85
+ if node .right < q_left or node .left > q_right : # 节点所在区间与 [q_left, q_right] 无关
86
+ return 0
87
+
88
+ self .__pushdown (node ) # 向下更新节点所在区间的左右子节点的值和懒惰标记
89
+
90
+ if q_left <= node .mid : # 在左子树中更新区间值
91
+ if not node .leftNode :
92
+ node .leftNode = TreeNode (node .left , node .mid )
93
+ self .__update_interval (q_left , q_right , val , node .leftNode )
94
+ if q_right > node .mid : # 在右子树中更新区间值
95
+ if not node .rightNode :
96
+ node .rightNode = TreeNode (node .mid + 1 , node .right )
97
+ self .__update_interval (q_left , q_right , val , node .rightNode )
98
+
99
+ self .__pushup (node )
100
+
101
+ # 向下更新 node 节点所在区间的左右子节点的值和懒惰标记
102
+ def __pushdown (self , node ):
103
+ lazy_tag = node .lazy_tag
104
+ if not node .lazy_tag :
105
+ return
106
+
107
+ if not node .leftNode :
108
+ node .leftNode = TreeNode (node .left , node .mid )
109
+ if not node .rightNode :
110
+ node .rightNode = TreeNode (node .mid + 1 , node .right )
111
+
112
+ if node .leftNode .lazy_tag :
113
+ node .leftNode .lazy_tag += lazy_tag # 更新左子节点懒惰标记
114
+ else :
115
+ node .leftNode .lazy_tag = lazy_tag # 更新左子节点懒惰标记
116
+ left_size = (node .leftNode .right - node .leftNode .left + 1 )
117
+ node .leftNode .val += lazy_tag * left_size # 左子节点每个元素值增加 lazy_tag
118
+
119
+ if node .rightNode .lazy_tag :
120
+ node .rightNode .lazy_tag += lazy_tag # 更新右子节点懒惰标记
121
+ else :
122
+ node .rightNode .lazy_tag = lazy_tag # 更新右子节点懒惰标记
123
+ right_size = (node .rightNode .right - node .rightNode .left + 1 )
124
+ node .rightNode .val += lazy_tag * right_size # 右子节点每个元素值增加 lazy_tag
125
+
126
+ node .lazy_tag = None # 更新当前节点的懒惰标记
127
+
128
+ def get_nums (self , length ):
129
+ nums = [0 for _ in range (length )]
130
+ for i in range (length ):
131
+ nums [i ] = self .query_interval (i , i )
132
+ return nums
0 commit comments