diff --git a/data_structures/binary_tree/red_black_tree.py b/data_structures/binary_tree/red_black_tree.py index 752db1e7026c..0e5c06d1c211 100644 --- a/data_structures/binary_tree/red_black_tree.py +++ b/data_structures/binary_tree/red_black_tree.py @@ -7,14 +7,34 @@ class RedBlackTree: """ A Red-Black tree, which is a self-balancing BST (binary search tree). - This tree has similar performance to AVL trees, but the balancing is - less strict, so it will perform faster for writing/deleting nodes - and slower for reading in the average case, though, because they're - both balanced binary search trees, both will get the same asymptotic - performance. - To read more about them, https://en.wikipedia.org/wiki/Red-black_tree - Unless otherwise specified, all asymptotic runtimes are specified in - terms of the size of the tree. + This tree has similar performance to AVL trees, but the balancing is + less strict, so it will perform faster for writing/deleting nodes + and slower for reading in the average case, though, because they're + both balanced binary search trees, both will get the same asymptotic + performance. + To read more about them, https://en.wikipedia.org/wiki/Red-black_tree + Unless otherwise specified, all asymptotic runtimes are specified in + terms of the size of the tree. 
+ Examples: + >>> tree = RedBlackTree(0) + >>> tree = tree.insert(8).insert(-8).insert(4).insert(12) + >>> tree.check_color_properties() + True + >>> list(tree.inorder_traverse()) + [-8, 0, 4, 8, 12] + >>> tree.search(4).label + 4 + >>> tree.floor(5) + 4 + >>> tree.ceil(5) + 8 + >>> tree.get_min() + -8 + >>> tree.get_max() + 12 + >>> tree = tree.remove(4) + >>> 4 in tree + False """ def __init__( @@ -25,12 +45,21 @@ def __init__( left: RedBlackTree | None = None, right: RedBlackTree | None = None, ) -> None: - """Initialize a new Red-Black Tree node with the given values: - label: The value associated with this node - color: 0 if black, 1 if red - parent: The parent to this node - left: This node's left child - right: This node's right child + """Initialize a new Red-Black Tree node. + + Args: + label: The value associated with this node + color: 0 if black, 1 if red + parent: The parent to this node + left: This node's left child + right: This node's right child + + Examples: + >>> node = RedBlackTree(5) + >>> node.label + 5 + >>> node.color + 0 """ self.label = label self.parent = parent @@ -38,12 +67,23 @@ def __init__( self.right = right self.color = color - # Here are functions which are specific to red-black trees - def rotate_left(self) -> RedBlackTree: - """Rotate the subtree rooted at this node to the left and - returns the new root to this subtree. - Performing one rotation can be done in O(1). + """Rotate the subtree rooted at this node to the left. 
+ + Returns: + The new root of the subtree + + Examples: + >>> root = RedBlackTree(2) + >>> root.right = RedBlackTree(4) + >>> root.right.left = RedBlackTree(3) + >>> new_root = root.rotate_left() + >>> new_root.label + 4 + >>> new_root.left.label + 2 + >>> new_root.left.right.label + 3 """ parent = self.parent right = self.right @@ -63,9 +103,22 @@ def rotate_left(self) -> RedBlackTree: return right def rotate_right(self) -> RedBlackTree: - """Rotate the subtree rooted at this node to the right and - returns the new root to this subtree. - Performing one rotation can be done in O(1). + """Rotate the subtree rooted at this node to the right. + + Returns: + The new root of the subtree + + Examples: + >>> root = RedBlackTree(4) + >>> root.left = RedBlackTree(2) + >>> root.left.right = RedBlackTree(3) + >>> new_root = root.rotate_right() + >>> new_root.label + 2 + >>> new_root.right.label + 4 + >>> new_root.right.left.label + 3 """ if self.left is None: return self @@ -85,13 +138,23 @@ def rotate_right(self) -> RedBlackTree: return left def insert(self, label: int) -> RedBlackTree: - """Inserts label into the subtree rooted at self, performs any - rotations necessary to maintain balance, and then returns the - new root to this subtree (likely self). - This is guaranteed to run in O(log(n)) time. + """Insert a label into the tree. 
+ + Args: + label: The value to insert + + Returns: + The root of the tree + + Examples: + >>> tree = RedBlackTree() + >>> tree = tree.insert(5).insert(3).insert(7) + >>> list(tree.inorder_traverse()) + [3, 5, 7] + >>> tree.check_color_properties() + True """ if self.label is None: - # Only possible with an empty tree self.label = label return self if self.label == label: @@ -110,7 +173,7 @@ def insert(self, label: int) -> RedBlackTree: return self.parent or self def _insert_repair(self) -> None: - """Repair the coloring from inserting into a tree.""" + """Repair the coloring after insertion.""" if self.parent is None: # This node is the root, so it just needs to be black self.color = 0 @@ -148,35 +211,43 @@ def _insert_repair(self) -> None: self.grandparent._insert_repair() def remove(self, label: int) -> RedBlackTree: - """Remove label from this tree.""" + """Remove a label from the tree. + + Args: + label: The value to remove + + Returns: + The root of the tree + + Examples: + >>> tree = RedBlackTree(5) + >>> tree = tree.insert(3).insert(7) + >>> tree = tree.remove(3) + >>> 3 in tree + False + >>> tree.check_color_properties() + True + """ if self.label == label: if self.left and self.right: # It's easier to balance a node with at most one child, # so we replace this node with the greatest one less than # it and remove that. + value = self.left.get_max() if value is not None: self.label = value self.left.remove(value) else: - # This node has at most one non-None child, so we don't - # need to replace child = self.left or self.right if self.color == 1: - # This node is red, and its child is black - # The only way this happens to a node with one child - # is if both children are None leaves. - # We can just remove this node and call it a day. 
if self.parent: if self.is_left(): self.parent.left = None else: self.parent.right = None - # The node is black elif child is None: - # This node and its child are black if self.parent is None: - # The tree is now empty return RedBlackTree(None) else: self._remove_repair() @@ -186,8 +257,6 @@ def remove(self, label: int) -> RedBlackTree: self.parent.right = None self.parent = None else: - # This node is black and its child is red - # Move the child node here and make it black self.label = child.label self.left = child.left self.right = child.right @@ -203,7 +272,7 @@ def remove(self, label: int) -> RedBlackTree: return self.parent or self def _remove_repair(self) -> None: - """Repair the coloring of the tree that may have been messed up.""" + """Repair the coloring after removal.""" if ( self.parent is None or self.sibling is None @@ -276,42 +345,28 @@ def _remove_repair(self) -> None: self.parent.sibling.color = 0 def check_color_properties(self) -> bool: - """Check the coloring of the tree, and return True iff the tree - is colored in a way which matches these five properties: - (wording stolen from wikipedia article) - 1. Each node is either red or black. - 2. The root node is black. - 3. All leaves are black. - 4. If a node is red, then both its children are black. - 5. Every path from any node to all of its descendent NIL nodes - has the same number of black nodes. - This function runs in O(n) time, because properties 4 and 5 take - that long to check. """ - # I assume property 1 to hold because there is nothing that can - # make the color be anything other than 0 or 1. - # Property 2 - if self.color: - # The root was red - print("Property 2") + Verify that all Red-Black Tree properties are satisfied: + 1. Root node is black + 2. No two consecutive red nodes + 3. 
All paths have same black height + + Returns: + True if all properties are satisfied, False otherwise + """ + # Property 1: Root must be black + if self.parent is None and self.color != 0: return False - # Property 3 does not need to be checked, because None is assumed - # to be black and is all the leaves. - # Property 4 + + # Property 2: No two consecutive red nodes if not self.check_coloring(): - print("Property 4") - return False - # Property 5 - if self.black_height() is None: - print("Property 5") return False - # All properties were met - return True + + # Property 3: All paths have same black height + return self.black_height() is not None def check_coloring(self) -> bool: - """A helper function to recursively check Property 4 of a - Red-Black Tree. See check_color_properties for more info. - """ + """Check that no red node in this subtree has a red child.""" if self.color == 1 and 1 in (color(self.left), color(self.right)): return False if self.left and not self.left.check_coloring(): @@ -319,38 +374,65 @@ def check_coloring(self) -> bool: return not (self.right and not self.right.check_coloring()) def black_height(self) -> int | None: - """Returns the number of black nodes from this node to the - leaves of the tree, or None if there isn't one such value (the - tree is color incorrectly). 
""" - if self is None or self.left is None or self.right is None: - # If we're already at a leaf, there is no path - return 1 - left = RedBlackTree.black_height(self.left) - right = RedBlackTree.black_height(self.right) - if left is None or right is None: - # There are issues with coloring below children nodes - return None - if left != right: - # The two children have unequal depths + Calculate the black height of the tree and verify consistency + - Black height = number of black nodes from current node to any leaf + - Returns None if any path has different black height + + Returns: + Black height if consistent, None otherwise + """ + # Leaf node case (both children are None) + if self.left is None and self.right is None: + # Count: current node (if black) + leaf (black) + return 1 + (1 - self.color) # 2 if black, 1 if red + + # Get black heights from both subtrees + left_bh = self.left.black_height() if self.left else 1 + right_bh = self.right.black_height() if self.right else 1 + + # Validate consistency + if left_bh is None or right_bh is None or left_bh != right_bh: return None - # Return the black depth of children, plus one if this node is - # black - return left + (1 - self.color) - # Here are functions which are general to all binary search trees + # Add current node's contribution (1 if black, 0 if red) + return left_bh + (1 - self.color) def __contains__(self, label: int) -> bool: - """Search through the tree for label, returning True iff it is - found somewhere in the tree. - Guaranteed to run in O(log(n)) time. + """Check if the tree contains a label. 
+ + Args: + label: The value to check + + Returns: + True if the label is in the tree, False otherwise + + Examples: + >>> tree = RedBlackTree(5) + >>> tree = tree.insert(3) + >>> 3 in tree + True + >>> 4 in tree + False """ return self.search(label) is not None def search(self, label: int) -> RedBlackTree | None: - """Search through the tree for label, returning its node if - it's found, and None otherwise. - This method is guaranteed to run in O(log(n)) time. + """Search for a label in the tree. + + Args: + label: The value to search for + + Returns: + The node containing the label, or None if not found + + Examples: + >>> tree = RedBlackTree(5) + >>> node = tree.search(5) + >>> node.label + 5 + >>> tree.search(10) is None + True """ if self.label == label: return self @@ -365,8 +447,22 @@ def search(self, label: int) -> RedBlackTree | None: return self.left.search(label) def floor(self, label: int) -> int | None: - """Returns the largest element in this tree which is at most label. - This method is guaranteed to run in O(log(n)) time.""" + """Find the largest element <= label. + + Args: + label: The value to find the floor of + + Returns: + The floor value, or None if no such element exists + + Examples: + >>> tree = RedBlackTree(5) + >>> tree = tree.insert(3).insert(7) + >>> tree.floor(6) + 5 + >>> tree.floor(2) is None + True + """ if self.label == label: return self.label elif self.label is not None and self.label > label: @@ -382,8 +478,21 @@ def floor(self, label: int) -> int | None: return self.label def ceil(self, label: int) -> int | None: - """Returns the smallest element in this tree which is at least label. - This method is guaranteed to run in O(log(n)) time. + """Find the smallest element >= label. 
+ + Args: + label: The value to find the ceil of + + Returns: + The ceil value, or None if no such element exists + + Examples: + >>> tree = RedBlackTree(5) + >>> tree = tree.insert(3).insert(7) + >>> tree.ceil(6) + 7 + >>> tree.ceil(8) is None + True """ if self.label == label: return self.label @@ -400,28 +509,42 @@ def ceil(self, label: int) -> int | None: return self.label def get_max(self) -> int | None: - """Returns the largest element in this tree. - This method is guaranteed to run in O(log(n)) time. + """Get the maximum element in the tree. + + Returns: + The maximum value, or None if the tree is empty + + Examples: + >>> tree = RedBlackTree(5) + >>> tree = tree.insert(3).insert(7) + >>> tree.get_max() + 7 """ if self.right: - # Go as far right as possible return self.right.get_max() else: return self.label def get_min(self) -> int | None: - """Returns the smallest element in this tree. - This method is guaranteed to run in O(log(n)) time. + """Get the minimum element in the tree. 
+ + Returns: + The minimum value, or None if the tree is empty + + Examples: + >>> tree = RedBlackTree(5) + >>> tree = tree.insert(3).insert(7) + >>> tree.get_min() + 3 """ if self.left: - # Go as far left as possible return self.left.get_min() else: return self.label @property def grandparent(self) -> RedBlackTree | None: - """Get the current node's grandparent, or None if it doesn't exist.""" + """Get the grandparent of this node.""" if self.parent is None: return None else: @@ -429,7 +552,7 @@ def grandparent(self) -> RedBlackTree | None: @property def sibling(self) -> RedBlackTree | None: - """Get the current node's sibling, or None if it doesn't exist.""" + """Get the sibling of this node.""" if self.parent is None: return None elif self.parent.left is self: @@ -438,23 +561,29 @@ def sibling(self) -> RedBlackTree | None: return self.parent.left def is_left(self) -> bool: - """Returns true iff this node is the left child of its parent.""" + """Check if this node is the left child of its parent.""" if self.parent is None: return False return self.parent.left is self def is_right(self) -> bool: - """Returns true iff this node is the right child of its parent.""" + """Check if this node is the right child of its parent.""" if self.parent is None: return False return self.parent.right is self def __bool__(self) -> bool: + """Return True if the tree is not empty.""" return True def __len__(self) -> int: - """ - Return the number of nodes in this tree. + """Return the number of nodes in the tree. + + Examples: + >>> tree = RedBlackTree(5) + >>> tree = tree.insert(3).insert(7) + >>> len(tree) + 3 """ ln = 1 if self.left: @@ -464,6 +593,18 @@ def __len__(self) -> int: return ln def preorder_traverse(self) -> Iterator[int | None]: + """Traverse the tree in pre-order. 
+ + Yields: + The values in pre-order + + Examples: + >>> tree = RedBlackTree(2) + >>> tree.left = RedBlackTree(1) + >>> tree.right = RedBlackTree(3) + >>> list(tree.preorder_traverse()) + [2, 1, 3] + """ yield self.label if self.left: yield from self.left.preorder_traverse() @@ -471,6 +612,18 @@ def preorder_traverse(self) -> Iterator[int | None]: yield from self.right.preorder_traverse() def inorder_traverse(self) -> Iterator[int | None]: + """Traverse the tree in in-order. + + Yields: + The values in in-order + + Examples: + >>> tree = RedBlackTree(2) + >>> tree.left = RedBlackTree(1) + >>> tree.right = RedBlackTree(3) + >>> list(tree.inorder_traverse()) + [1, 2, 3] + """ if self.left: yield from self.left.inorder_traverse() yield self.label @@ -478,6 +631,18 @@ def inorder_traverse(self) -> Iterator[int | None]: yield from self.right.inorder_traverse() def postorder_traverse(self) -> Iterator[int | None]: + """Traverse the tree in post-order. + + Yields: + The values in post-order + + Examples: + >>> tree = RedBlackTree(2) + >>> tree.left = RedBlackTree(1) + >>> tree.right = RedBlackTree(3) + >>> list(tree.postorder_traverse()) + [1, 3, 2] + """ if self.left: yield from self.left.postorder_traverse() if self.right: @@ -485,6 +650,7 @@ def postorder_traverse(self) -> Iterator[int | None]: yield self.label def __repr__(self) -> str: + """Return a string representation of the tree.""" from pprint import pformat if self.left is None and self.right is None: @@ -508,6 +674,10 @@ def __eq__(self, other: object) -> bool: else: return False + def __hash__(self): + """Return a hash value for the node.""" + return hash((self.label, self.color)) + def color(node: RedBlackTree | None) -> int: """Returns the color of a node, allowing for None leaves.""" @@ -525,7 +695,6 @@ def color(node: RedBlackTree | None) -> int: def test_rotations() -> bool: """Test that the rotate_left and rotate_right functions work.""" - # Make a tree to test on tree = RedBlackTree(0) tree.left = 
RedBlackTree(-10, parent=tree) tree.right = RedBlackTree(10, parent=tree) @@ -533,7 +702,6 @@ def test_rotations() -> bool: tree.left.right = RedBlackTree(-5, parent=tree.left) tree.right.left = RedBlackTree(5, parent=tree.right) tree.right.right = RedBlackTree(20, parent=tree.right) - # Make the right rotation left_rot = RedBlackTree(10) left_rot.left = RedBlackTree(0, parent=left_rot) left_rot.left.left = RedBlackTree(-10, parent=left_rot.left) @@ -546,7 +714,6 @@ def test_rotations() -> bool: return False tree = tree.rotate_right() tree = tree.rotate_right() - # Make the left rotation right_rot = RedBlackTree(-10) right_rot.left = RedBlackTree(-20, parent=right_rot) right_rot.right = RedBlackTree(0, parent=right_rot) @@ -598,16 +765,12 @@ def test_insert_and_search() -> bool: tree.insert(10) tree.insert(11) if any(i in tree for i in (5, -6, -10, 13)): - # Found something not in there return False - # Find all these things in there return all(i in tree for i in (11, 12, -8, 0)) def test_insert_delete() -> bool: - """Test the insert() and delete() method of the tree, verifying the - insertion and removal of elements, and the balancing of the tree. 
- """ + """Test the insert() and remove() methods of the tree.""" tree = RedBlackTree(0) tree = tree.insert(-12) tree = tree.insert(8) @@ -699,13 +862,21 @@ def main() -> None: """ >>> pytests() """ + import doctest + + failures, _ = doctest.testmod() + if failures == 0: + print("All doctests passed!") + else: + print(f"{failures} doctests failed!") + print_results("Rotating right and left", test_rotations()) print_results("Inserting", test_insert()) print_results("Searching", test_insert_and_search()) print_results("Deleting", test_insert_delete()) print_results("Floor and ceil", test_floor_ceil()) print_results("Tree traversal", test_tree_traversal()) - print_results("Tree traversal", test_tree_chaining()) + print_results("Tree chaining", test_tree_chaining()) print("Testing tree balancing...") print("This should only be a few seconds.") test_insertion_speed()