Skip to content

Commit 0658e2c

Browse files
committed
Update Tree-SegmentTree.py
1 parent 5e3b406 commit 0658e2c

File tree

1 file changed

+49
-25
lines changed

1 file changed

+49
-25
lines changed

Templates/07.Tree/Tree-SegmentTree.py

Lines changed: 49 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -6,31 +6,32 @@ def __init__(self, val=0):
66
self.val = val # 节点值(区间值)
77
self.lazy_tag = None # 区间和问题的延迟更新标记
88

9+
910
# 线段树类
1011
class SegmentTree:
1112
def __init__(self, nums, function):
1213
self.size = len(nums)
13-
self.tree = [TreeNode() for _ in range(4 * self.size)] # 维护 TreeNode 数组
14+
self.tree = [TreeNode() for _ in range(4 * self.size)] # 维护 TreeNode 数组
1415
self.nums = nums # 原始数据
1516
self.function = function # function 是一个函数,左右区间的聚合方法
1617
if self.size > 0:
1718
self.__build(0, 0, self.size - 1)
18-
19+
1920
# 构建线段树,节点的存储下标为 index,节点的区间为 [left, right]
2021
def __build(self, index, left, right):
2122
self.tree[index].left = left
2223
self.tree[index].right = right
2324
if left == right: # 叶子节点,节点值为对应位置的元素值
2425
self.tree[index].val = self.nums[left]
2526
return
26-
27+
2728
mid = left + (right - left) // 2 # 左右节点划分点
2829
left_index = index * 2 + 1 # 左子节点的存储下标
2930
right_index = index * 2 + 2 # 右子节点的存储下标
30-
self._build(left_index, left, mid) # 递归创建左子树
31-
self._build(right_index, mid + 1, right) # 递归创建右子树
31+
self.__build(left_index, left, mid) # 递归创建左子树
32+
self.__build(right_index, mid + 1, right) # 递归创建右子树
3233
self.__pushup(index) # 向上更新节点的区间值
33-
34+
3435
# 向上更新下标为 index 的节点区间值,节点的区间值等于该节点左右子节点元素值的聚合计算结果
3536
def __pushup(self, index):
3637
left_index = index * 2 + 1 # 左子节点的存储下标
@@ -56,7 +57,7 @@ def __update_point(self, i, val, index, left, right):
5657
else: # 在右子树中更新节点值
5758
self.__update_point(i, val, right_index, mid + 1, right)
5859
self.__pushup(index) # 向上更新节点的区间值
59-
60+
6061
# 区间查询,查询区间为 [q_left, q_right] 的区间值
6162
def query_interval(self, q_left, q_right):
6263
return self.__query_interval(q_left, q_right, 0, 0, self.size - 1)
@@ -67,7 +68,9 @@ def __query_interval(self, q_left, q_right, index, left, right):
6768
return self.tree[index].val # 直接返回节点值
6869
if right < q_left or left > q_right: # 节点所在区间与 [q_left, q_right] 无关
6970
return 0
70-
71+
72+
self.__pushdown(index)
73+
7174
mid = left + (right - left) // 2 # 左右节点划分点
7275
left_index = index * 2 + 1 # 左子节点的存储下标
7376
right_index = index * 2 + 2 # 右子节点的存储下标
@@ -87,48 +90,69 @@ def update_interval(self, q_left, q_right, val):
8790
def __update_interval(self, q_left, q_right, val, index, left, right):
8891

8992
if left >= q_left and right <= q_right: # 节点所在区间被 [q_left, q_right] 所覆盖
93+
# interval_size = (right - left + 1) # 当前节点所在区间大小
94+
# self.tree[index].val = interval_size * val # 当前节点所在区间每个元素值改为 val
9095
# self.tree[index].lazy_tag = val # 将当前节点的延迟标记为区间值
91-
# self.tree[index].val = val # 更新当前节点所在区间值
92-
93-
self.tree[index].lazy_tag += val # 将当前节点的延迟标记增加 val
96+
97+
if self.tree[index].lazy_tag:
98+
self.tree[index].lazy_tag += val # 将当前节点的延迟标记增加 val
99+
else:
100+
self.tree[index].lazy_tag = val # 将当前节点的延迟标记增加 val
94101
interval_size = (right - left + 1) # 当前节点所在区间大小
95-
self.tree[index].val += val * interval_size # 当前节点所在区间每个元素值增加 val
102+
self.tree[index].val += val * interval_size # 当前节点所在区间每个元素值增加 val
96103
return
97104
if right < q_left or left > q_right: # 节点所在区间与 [q_left, q_right] 无关
98105
return 0
99-
100-
if self.tree[index].lazy_tag: # 需要向下更新节点所在区间的左右子节点的值和懒惰标记
101-
self.__pushdown(index)
102-
106+
107+
self.__pushdown(index)
108+
103109
mid = left + (right - left) // 2 # 左右节点划分点
104110
left_index = index * 2 + 1 # 左子节点的存储下标
105111
right_index = index * 2 + 2 # 右子节点的存储下标
106112
if q_left <= mid: # 在左子树中更新区间值
107-
self._update_interval(q_left, q_right, val, left_index, left, mid)
113+
self.__update_interval(q_left, q_right, val, left_index, left, mid)
108114
if q_right > mid: # 在右子树中更新区间值
109-
self._update_interval(q_left, q_right, val, right_index, mid + 1, right)
115+
self.__update_interval(q_left, q_right, val, right_index, mid + 1, right)
116+
110117
self.__pushup(index)
111118

112119
# 向下更新下标为 index 的节点所在区间的左右子节点的值和懒惰标记
113120
def __pushdown(self, index):
114121
lazy_tag = self.tree[index].lazy_tag
122+
if not lazy_tag:
123+
return
124+
115125
left_index = index * 2 + 1 # 左子节点的存储下标
116126
right_index = index * 2 + 2 # 右子节点的存储下标
117127

118128
# self.tree[left_index].lazy_tag = lazy_tag # 更新左子节点懒惰标记
119-
# self.tree[left_index].val = lazy_tag # 更新左子节点值
120-
#
129+
# left_size = (self.tree[left_index].right - self.tree[left_index].left + 1)
130+
# self.tree[left_index].val = lazy_tag * left_size # 更新左子节点值
131+
#
121132
# self.tree[right_index].lazy_tag = lazy_tag # 更新右子节点懒惰标记
122-
# self.tree[right_index].val = lazy_tag # 更新右子节点值
123-
#
133+
# right_size = (self.tree[right_index].right - self.tree[right_index].left + 1)
134+
# self.tree[right_index].val = lazy_tag * right_size # 更新右子节点值
135+
#
124136
# self.tree[index].lazy_tag = None # 更新当前节点的懒惰标记
125137

126-
self.tree[left_index].lazy_tag += lazy_tag # 更新左子节点懒惰标记
138+
if self.tree[left_index].lazy_tag:
139+
self.tree[left_index].lazy_tag += lazy_tag # 更新左子节点懒惰标记
140+
else:
141+
self.tree[left_index].lazy_tag = lazy_tag
127142
left_size = (self.tree[left_index].right - self.tree[left_index].left + 1)
128143
self.tree[left_index].val += lazy_tag * left_size # 左子节点每个元素值增加 lazy_tag
129144

130-
self.tree[right_index].lazy_tag += lazy_tag # 更新右子节点懒惰标记
145+
if self.tree[right_index].lazy_tag:
146+
self.tree[right_index].lazy_tag += lazy_tag # 更新右子节点懒惰标记
147+
else:
148+
self.tree[right_index].lazy_tag = lazy_tag
131149
right_size = (self.tree[right_index].right - self.tree[right_index].left + 1)
132150
self.tree[right_index].val += lazy_tag * right_size # 右子节点每个元素值增加 lazy_tag
133151

134-
self.tree[index].lazy_tag = None # 更新当前节点的懒惰标记
152+
self.tree[index].lazy_tag = None # 更新当前节点的懒惰标记
153+
154+
# 获取 nums 数组
155+
def get_nums(self):
156+
for i in range(self.size):
157+
self.nums[i] = self.query_interval(i, i)
158+
return self.nums

0 commit comments

Comments
 (0)