Skip to content

Commit

Permalink
Create Tree-DynamicSegmentTree.py
Browse files Browse the repository at this point in the history
  • Loading branch information
itcharge committed Mar 6, 2022
1 parent 0658e2c commit 7b8b728
Showing 1 changed file with 132 additions and 0 deletions.
132 changes: 132 additions & 0 deletions Templates/07.Tree/Tree-DynamicSegmentTree.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
# 线段树的节点类
class TreeNode:
def __init__(self, left=-1, right=-1, val=0):
self.left = left # 区间左边界
self.right = right # 区间右边界
self.mid = left + (right - left) // 2
self.leftNode = None # 区间左节点
self.rightNode = None # 区间右节点
self.val = val # 节点值(区间值)
self.lazy_tag = None # 区间问题的延迟更新标记


# 线段树类
class SegmentTree:
def __init__(self, function):
self.tree = TreeNode(0, int(1e9))
self.function = function # function 是一个函数,左右区间的聚合方法

# 向上更新 node 节点区间值,节点的区间值等于该节点左右子节点元素值的聚合计算结果
def __pushup(self, node):
leftNode = node.leftNode
rightNode = node.rightNode
if leftNode and rightNode:
node.val = self.function(leftNode.val, rightNode.val)

# 单点更新,将 nums[i] 更改为 val
def update_point(self, i, val):
self.__update_point(i, val, self.tree)

# 单点更新,将 nums[i] 更改为 val。node 节点的区间为 [node.left, node.right]
def __update_point(self, i, val, node):
if node.left == node.right:
node.val = val # 叶子节点,节点值修改为 val
return

if i <= node.mid: # 在左子树中更新节点值
if not node.leftNode:
node.leftNode = TreeNode(node.left, node.mid)
self.__update_point(i, val, node.leftNode)
else: # 在右子树中更新节点值
if not node.rightNode:
node.rightNode = TreeNode(node.mid + 1, node.right)
self.__update_point(i, val, node.rightNode)
self.__pushup(node) # 向上更新节点的区间值

# 区间查询,查询区间为 [q_left, q_right] 的区间值
def query_interval(self, q_left, q_right):
return self.__query_interval(q_left, q_right, self.tree)

# 区间查询,在线段树的 [left, right] 区间范围中搜索区间为 [q_left, q_right] 的区间值
def __query_interval(self, q_left, q_right, node):
if node.left >= q_left and node.right <= q_right: # 节点所在区间被 [q_left, q_right] 所覆盖
return node.val # 直接返回节点值
if node.right < q_left or node.left > q_right: # 节点所在区间与 [q_left, q_right] 无关
return 0

self.__pushdown(node) # 向下更新节点所在区间的左右子节点的值和懒惰标记

res_left = 0 # 左子树查询结果
res_right = 0 # 右子树查询结果
if q_left <= node.mid: # 在左子树中查询
if not node.leftNode:
node.leftNode = TreeNode(node.left, node.mid)
res_left = self.__query_interval(q_left, q_right, node.leftNode)
if q_right > node.mid: # 在右子树中查询
if not node.rightNode:
node.rightNode = TreeNode(node.mid + 1, node.right)
res_right = self.__query_interval(q_left, q_right, node.rightNode)
return self.function(res_left, res_right) # 返回左右子树元素值的聚合计算结果

# 区间更新,将区间为 [q_left, q_right] 上的元素值修改为 val
def update_interval(self, q_left, q_right, val):
self.__update_interval(q_left, q_right, val, self.tree)

# 区间更新
def __update_interval(self, q_left, q_right, val, node):
if node.left >= q_left and node.right <= q_right: # 节点所在区间被 [q_left, q_right] 所覆盖
if node.lazy_tag:
node.lazy_tag += val # 将当前节点的延迟标记增加 val
else:
node.lazy_tag = val # 将当前节点的延迟标记增加 val
interval_size = (node.right - node.left + 1) # 当前节点所在区间大小
node.val += val * interval_size # 当前节点所在区间每个元素值增加 val
return
if node.right < q_left or node.left > q_right: # 节点所在区间与 [q_left, q_right] 无关
return 0

self.__pushdown(node) # 向下更新节点所在区间的左右子节点的值和懒惰标记

if q_left <= node.mid: # 在左子树中更新区间值
if not node.leftNode:
node.leftNode = TreeNode(node.left, node.mid)
self.__update_interval(q_left, q_right, val, node.leftNode)
if q_right > node.mid: # 在右子树中更新区间值
if not node.rightNode:
node.rightNode = TreeNode(node.mid + 1, node.right)
self.__update_interval(q_left, q_right, val, node.rightNode)

self.__pushup(node)

# 向下更新 node 节点所在区间的左右子节点的值和懒惰标记
def __pushdown(self, node):
lazy_tag = node.lazy_tag
if not node.lazy_tag:
return

if not node.leftNode:
node.leftNode = TreeNode(node.left, node.mid)
if not node.rightNode:
node.rightNode = TreeNode(node.mid + 1, node.right)

if node.leftNode.lazy_tag:
node.leftNode.lazy_tag += lazy_tag # 更新左子节点懒惰标记
else:
node.leftNode.lazy_tag = lazy_tag # 更新左子节点懒惰标记
left_size = (node.leftNode.right - node.leftNode.left + 1)
node.leftNode.val += lazy_tag * left_size # 左子节点每个元素值增加 lazy_tag

if node.rightNode.lazy_tag:
node.rightNode.lazy_tag += lazy_tag # 更新右子节点懒惰标记
else:
node.rightNode.lazy_tag = lazy_tag # 更新右子节点懒惰标记
right_size = (node.rightNode.right - node.rightNode.left + 1)
node.rightNode.val += lazy_tag * right_size # 右子节点每个元素值增加 lazy_tag

node.lazy_tag = None # 更新当前节点的懒惰标记

def get_nums(self, length):
nums = [0 for _ in range(length)]
for i in range(length):
nums[i] = self.query_interval(i, i)
return nums

0 comments on commit 7b8b728

Please sign in to comment.