@@ -6,31 +6,32 @@ def __init__(self, val=0):
6
6
self .val = val # 节点值(区间值)
7
7
self .lazy_tag = None # 区间和问题的延迟更新标记
8
8
9
+
9
10
# 线段树类
10
11
class SegmentTree :
11
12
def __init__ (self , nums , function ):
12
13
self .size = len (nums )
13
- self .tree = [TreeNode () for _ in range (4 * self .size )] # 维护 TreeNode 数组
14
+ self .tree = [TreeNode () for _ in range (4 * self .size )] # 维护 TreeNode 数组
14
15
self .nums = nums # 原始数据
15
16
self .function = function # function 是一个函数,左右区间的聚合方法
16
17
if self .size > 0 :
17
18
self .__build (0 , 0 , self .size - 1 )
18
-
19
+
19
20
# 构建线段树,节点的存储下标为 index,节点的区间为 [left, right]
20
21
def __build (self , index , left , right ):
21
22
self .tree [index ].left = left
22
23
self .tree [index ].right = right
23
24
if left == right : # 叶子节点,节点值为对应位置的元素值
24
25
self .tree [index ].val = self .nums [left ]
25
26
return
26
-
27
+
27
28
mid = left + (right - left ) // 2 # 左右节点划分点
28
29
left_index = index * 2 + 1 # 左子节点的存储下标
29
30
right_index = index * 2 + 2 # 右子节点的存储下标
30
- self ._build (left_index , left , mid ) # 递归创建左子树
31
- self ._build (right_index , mid + 1 , right ) # 递归创建右子树
31
+ self .__build (left_index , left , mid ) # 递归创建左子树
32
+ self .__build (right_index , mid + 1 , right ) # 递归创建右子树
32
33
self .__pushup (index ) # 向上更新节点的区间值
33
-
34
+
34
35
# 向上更新下标为 index 的节点区间值,节点的区间值等于该节点左右子节点元素值的聚合计算结果
35
36
def __pushup (self , index ):
36
37
left_index = index * 2 + 1 # 左子节点的存储下标
@@ -56,7 +57,7 @@ def __update_point(self, i, val, index, left, right):
56
57
else : # 在右子树中更新节点值
57
58
self .__update_point (i , val , right_index , mid + 1 , right )
58
59
self .__pushup (index ) # 向上更新节点的区间值
59
-
60
+
60
61
# 区间查询,查询区间为 [q_left, q_right] 的区间值
61
62
def query_interval (self , q_left , q_right ):
62
63
return self .__query_interval (q_left , q_right , 0 , 0 , self .size - 1 )
@@ -67,7 +68,9 @@ def __query_interval(self, q_left, q_right, index, left, right):
67
68
return self .tree [index ].val # 直接返回节点值
68
69
if right < q_left or left > q_right : # 节点所在区间与 [q_left, q_right] 无关
69
70
return 0
70
-
71
+
72
+ self .__pushdown (index )
73
+
71
74
mid = left + (right - left ) // 2 # 左右节点划分点
72
75
left_index = index * 2 + 1 # 左子节点的存储下标
73
76
right_index = index * 2 + 2 # 右子节点的存储下标
@@ -87,48 +90,69 @@ def update_interval(self, q_left, q_right, val):
87
90
def __update_interval (self , q_left , q_right , val , index , left , right ):
88
91
89
92
if left >= q_left and right <= q_right : # 节点所在区间被 [q_left, q_right] 所覆盖
93
+ # interval_size = (right - left + 1) # 当前节点所在区间大小
94
+ # self.tree[index].val = interval_size * val # 当前节点所在区间每个元素值改为 val
90
95
# self.tree[index].lazy_tag = val # 将当前节点的延迟标记为区间值
91
- # self.tree[index].val = val # 更新当前节点所在区间值
92
-
93
- self .tree [index ].lazy_tag += val # 将当前节点的延迟标记增加 val
96
+
97
+ if self .tree [index ].lazy_tag :
98
+ self .tree [index ].lazy_tag += val # 将当前节点的延迟标记增加 val
99
+ else :
100
+ self .tree [index ].lazy_tag = val # 将当前节点的延迟标记增加 val
94
101
interval_size = (right - left + 1 ) # 当前节点所在区间大小
95
- self .tree [index ].val += val * interval_size # 当前节点所在区间每个元素值增加 val
102
+ self .tree [index ].val += val * interval_size # 当前节点所在区间每个元素值增加 val
96
103
return
97
104
if right < q_left or left > q_right : # 节点所在区间与 [q_left, q_right] 无关
98
105
return 0
99
-
100
- if self .tree [index ].lazy_tag : # 需要向下更新节点所在区间的左右子节点的值和懒惰标记
101
- self .__pushdown (index )
102
-
106
+
107
+ self .__pushdown (index )
108
+
103
109
mid = left + (right - left ) // 2 # 左右节点划分点
104
110
left_index = index * 2 + 1 # 左子节点的存储下标
105
111
right_index = index * 2 + 2 # 右子节点的存储下标
106
112
if q_left <= mid : # 在左子树中更新区间值
107
- self ._update_interval (q_left , q_right , val , left_index , left , mid )
113
+ self .__update_interval (q_left , q_right , val , left_index , left , mid )
108
114
if q_right > mid : # 在右子树中更新区间值
109
- self ._update_interval (q_left , q_right , val , right_index , mid + 1 , right )
115
+ self .__update_interval (q_left , q_right , val , right_index , mid + 1 , right )
116
+
110
117
self .__pushup (index )
111
118
112
119
# 向下更新下标为 index 的节点所在区间的左右子节点的值和懒惰标记
113
120
def __pushdown (self , index ):
114
121
lazy_tag = self .tree [index ].lazy_tag
122
+ if not lazy_tag :
123
+ return
124
+
115
125
left_index = index * 2 + 1 # 左子节点的存储下标
116
126
right_index = index * 2 + 2 # 右子节点的存储下标
117
127
118
128
# self.tree[left_index].lazy_tag = lazy_tag # 更新左子节点懒惰标记
119
- # self.tree[left_index].val = lazy_tag # 更新左子节点值
120
- #
129
+ # left_size = (self.tree[left_index].right - self.tree[left_index].left + 1)
130
+ # self.tree[left_index].val = lazy_tag * left_size # 更新左子节点值
131
+ #
121
132
# self.tree[right_index].lazy_tag = lazy_tag # 更新右子节点懒惰标记
122
- # self.tree[right_index].val = lazy_tag # 更新右子节点值
123
- #
133
+ # right_size = (self.tree[right_index].right - self.tree[right_index].left + 1)
134
+ # self.tree[right_index].val = lazy_tag * right_size # 更新右子节点值
135
+ #
124
136
# self.tree[index].lazy_tag = None # 更新当前节点的懒惰标记
125
137
126
- self .tree [left_index ].lazy_tag += lazy_tag # 更新左子节点懒惰标记
138
+ if self .tree [left_index ].lazy_tag :
139
+ self .tree [left_index ].lazy_tag += lazy_tag # 更新左子节点懒惰标记
140
+ else :
141
+ self .tree [left_index ].lazy_tag = lazy_tag
127
142
left_size = (self .tree [left_index ].right - self .tree [left_index ].left + 1 )
128
143
self .tree [left_index ].val += lazy_tag * left_size # 左子节点每个元素值增加 lazy_tag
129
144
130
- self .tree [right_index ].lazy_tag += lazy_tag # 更新右子节点懒惰标记
145
+ if self .tree [right_index ].lazy_tag :
146
+ self .tree [right_index ].lazy_tag += lazy_tag # 更新右子节点懒惰标记
147
+ else :
148
+ self .tree [right_index ].lazy_tag = lazy_tag
131
149
right_size = (self .tree [right_index ].right - self .tree [right_index ].left + 1 )
132
150
self .tree [right_index ].val += lazy_tag * right_size # 右子节点每个元素值增加 lazy_tag
133
151
134
- self .tree [index ].lazy_tag = None # 更新当前节点的懒惰标记
152
+ self .tree [index ].lazy_tag = None # 更新当前节点的懒惰标记
153
+
154
+ # 获取 nums 数组
155
+ def get_nums (self ):
156
+ for i in range (self .size ):
157
+ self .nums [i ] = self .query_interval (i , i )
158
+ return self .nums
0 commit comments