Skip to content

Commit 7b8b728

Browse files
committed
Create Tree-DynamicSegmentTree.py
1 parent 0658e2c commit 7b8b728

File tree

1 file changed

+132
-0
lines changed

1 file changed

+132
-0
lines changed
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
# 线段树的节点类
2+
class TreeNode:
3+
def __init__(self, left=-1, right=-1, val=0):
4+
self.left = left # 区间左边界
5+
self.right = right # 区间右边界
6+
self.mid = left + (right - left) // 2
7+
self.leftNode = None # 区间左节点
8+
self.rightNode = None # 区间右节点
9+
self.val = val # 节点值(区间值)
10+
self.lazy_tag = None # 区间问题的延迟更新标记
11+
12+
13+
# 线段树类
14+
class SegmentTree:
15+
def __init__(self, function):
16+
self.tree = TreeNode(0, int(1e9))
17+
self.function = function # function 是一个函数,左右区间的聚合方法
18+
19+
# 向上更新 node 节点区间值,节点的区间值等于该节点左右子节点元素值的聚合计算结果
20+
def __pushup(self, node):
21+
leftNode = node.leftNode
22+
rightNode = node.rightNode
23+
if leftNode and rightNode:
24+
node.val = self.function(leftNode.val, rightNode.val)
25+
26+
# 单点更新,将 nums[i] 更改为 val
27+
def update_point(self, i, val):
28+
self.__update_point(i, val, self.tree)
29+
30+
# 单点更新,将 nums[i] 更改为 val。node 节点的区间为 [node.left, node.right]
31+
def __update_point(self, i, val, node):
32+
if node.left == node.right:
33+
node.val = val # 叶子节点,节点值修改为 val
34+
return
35+
36+
if i <= node.mid: # 在左子树中更新节点值
37+
if not node.leftNode:
38+
node.leftNode = TreeNode(node.left, node.mid)
39+
self.__update_point(i, val, node.leftNode)
40+
else: # 在右子树中更新节点值
41+
if not node.rightNode:
42+
node.rightNode = TreeNode(node.mid + 1, node.right)
43+
self.__update_point(i, val, node.rightNode)
44+
self.__pushup(node) # 向上更新节点的区间值
45+
46+
# 区间查询,查询区间为 [q_left, q_right] 的区间值
47+
def query_interval(self, q_left, q_right):
48+
return self.__query_interval(q_left, q_right, self.tree)
49+
50+
# 区间查询,在线段树的 [left, right] 区间范围中搜索区间为 [q_left, q_right] 的区间值
51+
def __query_interval(self, q_left, q_right, node):
52+
if node.left >= q_left and node.right <= q_right: # 节点所在区间被 [q_left, q_right] 所覆盖
53+
return node.val # 直接返回节点值
54+
if node.right < q_left or node.left > q_right: # 节点所在区间与 [q_left, q_right] 无关
55+
return 0
56+
57+
self.__pushdown(node) # 向下更新节点所在区间的左右子节点的值和懒惰标记
58+
59+
res_left = 0 # 左子树查询结果
60+
res_right = 0 # 右子树查询结果
61+
if q_left <= node.mid: # 在左子树中查询
62+
if not node.leftNode:
63+
node.leftNode = TreeNode(node.left, node.mid)
64+
res_left = self.__query_interval(q_left, q_right, node.leftNode)
65+
if q_right > node.mid: # 在右子树中查询
66+
if not node.rightNode:
67+
node.rightNode = TreeNode(node.mid + 1, node.right)
68+
res_right = self.__query_interval(q_left, q_right, node.rightNode)
69+
return self.function(res_left, res_right) # 返回左右子树元素值的聚合计算结果
70+
71+
# 区间更新,将区间为 [q_left, q_right] 上的元素值修改为 val
72+
def update_interval(self, q_left, q_right, val):
73+
self.__update_interval(q_left, q_right, val, self.tree)
74+
75+
# 区间更新
76+
def __update_interval(self, q_left, q_right, val, node):
77+
if node.left >= q_left and node.right <= q_right: # 节点所在区间被 [q_left, q_right] 所覆盖
78+
if node.lazy_tag:
79+
node.lazy_tag += val # 将当前节点的延迟标记增加 val
80+
else:
81+
node.lazy_tag = val # 将当前节点的延迟标记增加 val
82+
interval_size = (node.right - node.left + 1) # 当前节点所在区间大小
83+
node.val += val * interval_size # 当前节点所在区间每个元素值增加 val
84+
return
85+
if node.right < q_left or node.left > q_right: # 节点所在区间与 [q_left, q_right] 无关
86+
return 0
87+
88+
self.__pushdown(node) # 向下更新节点所在区间的左右子节点的值和懒惰标记
89+
90+
if q_left <= node.mid: # 在左子树中更新区间值
91+
if not node.leftNode:
92+
node.leftNode = TreeNode(node.left, node.mid)
93+
self.__update_interval(q_left, q_right, val, node.leftNode)
94+
if q_right > node.mid: # 在右子树中更新区间值
95+
if not node.rightNode:
96+
node.rightNode = TreeNode(node.mid + 1, node.right)
97+
self.__update_interval(q_left, q_right, val, node.rightNode)
98+
99+
self.__pushup(node)
100+
101+
# 向下更新 node 节点所在区间的左右子节点的值和懒惰标记
102+
def __pushdown(self, node):
103+
lazy_tag = node.lazy_tag
104+
if not node.lazy_tag:
105+
return
106+
107+
if not node.leftNode:
108+
node.leftNode = TreeNode(node.left, node.mid)
109+
if not node.rightNode:
110+
node.rightNode = TreeNode(node.mid + 1, node.right)
111+
112+
if node.leftNode.lazy_tag:
113+
node.leftNode.lazy_tag += lazy_tag # 更新左子节点懒惰标记
114+
else:
115+
node.leftNode.lazy_tag = lazy_tag # 更新左子节点懒惰标记
116+
left_size = (node.leftNode.right - node.leftNode.left + 1)
117+
node.leftNode.val += lazy_tag * left_size # 左子节点每个元素值增加 lazy_tag
118+
119+
if node.rightNode.lazy_tag:
120+
node.rightNode.lazy_tag += lazy_tag # 更新右子节点懒惰标记
121+
else:
122+
node.rightNode.lazy_tag = lazy_tag # 更新右子节点懒惰标记
123+
right_size = (node.rightNode.right - node.rightNode.left + 1)
124+
node.rightNode.val += lazy_tag * right_size # 右子节点每个元素值增加 lazy_tag
125+
126+
node.lazy_tag = None # 更新当前节点的懒惰标记
127+
128+
def get_nums(self, length):
129+
nums = [0 for _ in range(length)]
130+
for i in range(length):
131+
nums[i] = self.query_interval(i, i)
132+
return nums

0 commit comments

Comments
 (0)