diff --git a/src/cloudai/parser/system_parser/slurm_system_parser.py b/src/cloudai/parser/system_parser/slurm_system_parser.py
index bc3345084..59aaa05ba 100644
--- a/src/cloudai/parser/system_parser/slurm_system_parser.py
+++ b/src/cloudai/parser/system_parser/slurm_system_parser.py
@@ -89,7 +89,9 @@ def safe_int(value):
                 raise ValueError("Partition data does not include a 'name' field.")
 
             raw_nodes = partition_data.get("nodes", [])
-            node_names = set(SlurmSystem.parse_node_list(raw_nodes))
+            node_names = set()
+            for group in raw_nodes:
+                node_names.update(SlurmSystem.parse_node_list(group))
 
             if not node_names:
                 raise ValueError(f"No valid nodes found in partition '{partition_name}'")
@@ -117,7 +119,9 @@ def safe_int(value):
                     raise ValueError("Group data does not include a 'name' field.")
 
                 raw_nodes = group_data.get("nodes", [])
-                group_node_names = set(SlurmSystem.parse_node_list(raw_nodes))
+                group_node_names = set()
+                for group in raw_nodes:
+                    group_node_names.update(SlurmSystem.parse_node_list(group))
 
                 group_nodes = []
                 for group_node_name in group_node_names:
diff --git a/src/cloudai/schema/system/slurm/slurm_system.py b/src/cloudai/schema/system/slurm/slurm_system.py
index ff1250a27..179b70442 100644
--- a/src/cloudai/schema/system/slurm/slurm_system.py
+++ b/src/cloudai/schema/system/slurm/slurm_system.py
@@ -49,40 +49,54 @@ class SlurmSystem(System):
     """
 
     @classmethod
-    def parse_node_list(cls, node_list: List[str]) -> List[str]:
+    def parse_node_list(cls, node_list: str) -> List[str]:
         """
-        Expand a list of node names (with ranges) into a flat list of individual node names, keeping leading zeroes.
+        Expand a Slurm node list expression into a flat list of individual node names, keeping leading zeroes.
+
+        Comma-separated groups such as "nodeA[001-003],nodeB005" are expanded one by one; commas inside
+        brackets separate the ranges of a single group.
 
         Args:
-            node_list (List[str]): A list of node names, possibly including ranges.
+            node_list (str): A node list expression, possibly including bracketed ranges.
 
         Returns:
             List[str]: A flat list of expanded node names with preserved zeroes.
         """
-        expanded_nodes = []
-        for node in node_list:
-            if "[" in node and "]" in node:
-                prefix, ranges = node.split("[")
-                ranges = ranges.strip("]")
-                range_elements = ranges.split(",")
-                for r in range_elements:
-                    if "-" in r:
-                        start_str, end_str = r.split("-")
-                    else:
-                        # For single nodes, treat the node itself as both start and end.
-                        start_str = end_str = r
-
-                    start, end = int(start_str), int(end_str)
-                    max_length = max(len(start_str), len(end_str))
-
-                    if "-" in r:
-                        expanded_nodes.extend([f"{prefix}{str(i).zfill(max_length)}" for i in range(start, end + 1)])
-                    else:
-                        # For single nodes, append directly with appropriate padding.
-                        expanded_nodes.append(f"{prefix}{start_str.zfill(max_length)}")
-            else:
-                expanded_nodes.append(node)
-        return expanded_nodes
+        groups = []
+        depth = 0
+        group_start = 0
+        # Split on top-level commas only; commas inside brackets belong to a single group's ranges.
+        for pos, char in enumerate(node_list):
+            if char == "[":
+                depth += 1
+            elif char == "]":
+                depth -= 1
+            elif char == "," and depth == 0:
+                groups.append(node_list[group_start:pos])
+                group_start = pos + 1
+        groups.append(node_list[group_start:])
+
+        nodes = []
+        for group in groups:
+            group = group.strip()
+            if not group:
+                continue
+            if "[" not in group:
+                nodes.append(group)
+                continue
+            header, node_number = group.split("[", 1)
+            for r in node_number.rstrip("]").split(","):
+                if "-" in r:
+                    start_node, end_node = r.split("-")
+                    # Pad to the wider bound so zero-padded names survive whichever side carries the zeroes.
+                    number_of_digits = max(len(start_node), len(end_node))
+                    nodes.extend(
+                        f"{header}{str(i).zfill(number_of_digits)}" for i in range(int(start_node), int(end_node) + 1)
+                    )
+                else:
+                    nodes.append(f"{header}{r}")
+
+        return nodes
 
     @classmethod
     def format_node_list(cls, node_names: List[str]) -> str:
@@ -567,11 +581,8 @@ def parse_squeue_output(self, squeue_output: str) -> Dict[str, str]:
 
                 node_list_part, user = parts[0], "|".join(parts[1:])
                 # Handle cases where multiple node groups or ranges are specified
-                node_groups = node_list_part.split(",")
-                for node_group in node_groups:
-                    # Process each node or range using parse_node_list
-                    for node in self.parse_node_list([node_group.strip()]):
-                        node_user_map[node] = user.strip()
+                for node in self.parse_node_list(node_list_part):
+                    node_user_map[node] = user.strip()
 
         return node_user_map
 
@@ -589,20 +600,19 @@ def parse_sinfo_output(self, sinfo_output: str, node_user_map: Dict[str, str]) -
             parts = line.split()
             partition, _, _, _, state, nodelist = parts[:6]
             partition = partition.rstrip("*")
-
-            node_groups = nodelist.split(",")
-            for node_group in node_groups:
-                node_names = self.parse_node_list([node_group.strip()])
-                state_enum = self.convert_state_to_enum(state)
-
-                for node_name in node_names:
-                    for part_name, nodes in self.partitions.items():
-                        if part_name != partition:
-                            continue
-                        for node in nodes:
-                            if node.name == node_name:
-                                node.state = state_enum
-                                node.user = node_user_map.get(node_name, "N/A")
+            node_names = self.parse_node_list(nodelist)
+
+            # Convert state to enum, handling states with suffixes
+            state_enum = self.convert_state_to_enum(state)
+
+            # Only nodes of the reported partition can match, so look that partition up directly
+            partition_nodes = self.partitions.get(partition, [])
+            for node_name in node_names:
+                for node in partition_nodes:
+                    if node.name == node_name:
+                        node.state = state_enum
+                        node.user = node_user_map.get(node_name, "N/A")
+                        break
 
     def convert_state_to_enum(self, state_str: str) -> SlurmNodeState:
         """
@@ -701,7 +711,7 @@ def parse_nodes(self, nodes: List[str]) -> List[str]:
             else:
                 # Handle both individual node names and ranges
                 if self.is_node_in_system(node_spec) or "[" in node_spec:
-                    expanded_nodes = self.parse_node_list([node_spec])
+                    expanded_nodes = self.parse_node_list(node_spec)
                     parsed_nodes += expanded_nodes
                 else:
                     raise ValueError(f"Node '{node_spec}' not found.")
diff --git a/tests/test_slurm_system.py b/tests/test_slurm_system.py
index 5b98e9b96..cdc8fda24 100644
--- a/tests/test_slurm_system.py
+++ b/tests/test_slurm_system.py
@@ -1,3 +1,4 @@
+from typing import List
 from unittest.mock import patch
 
 import pytest
@@ -7,16 +8,17 @@
 
 @pytest.fixture
 def slurm_system():
-    nodes = [
-        SlurmNode(name="nodeA001", partition="main", state=SlurmNodeState.UNKNOWN_STATE),
-        SlurmNode(name="nodeB001", partition="main", state=SlurmNodeState.UNKNOWN_STATE),
+    nodes = [SlurmNode(name=f"node-0{i}", partition="main", state=SlurmNodeState.UNKNOWN_STATE) for i in range(33, 65)]
+    backup_nodes = [
+        SlurmNode(name=f"node0{i}", partition="backup", state=SlurmNodeState.UNKNOWN_STATE) for i in range(1, 9)
     ]
+
     system = SlurmSystem(
         name="test_system",
         install_path="/fake/path",
         output_path="/fake/output",
         default_partition="main",
-        partitions={"main": nodes},
+        partitions={"main": nodes, "backup": backup_nodes},
     )
     return system
 
@@ -39,24 +41,99 @@ def test_parse_squeue_output_with_node_ranges_and_root_user(slurm_system):
 
 
 def test_parse_sinfo_output(slurm_system):
-    sinfo_output = (
-        "PARTITION AVAIL TIMELIMIT NODES STATE NODELIST\n"
-        "main up infinite 1 idle nodeA001\n"
-        "main up infinite 1 idle nodeB001"
-    )
-    node_user_map = {"nodeA001": "root", "nodeB001": "user"}
+    sinfo_output = """PARTITION AVAIL TIMELIMIT NODES STATE NODELIST
+    main up 3:00:00 1 inval node-036
+    main up 3:00:00 5 drain node-[045-046,059,061-062]
+    main up 3:00:00 2 resv node-[034-035]
+    main up 3:00:00 24 alloc node-[033,037-044,047-058,060,063-064]
+    backup up 12:00:00 8 idle node[01-08]
+    """
+    node_user_map = {
+        "": "user1",
+        "node-033": "user2",
+        "node-037": "user3",
+        "node-038": "user3",
+        "node-039": "user3",
+        "node-040": "user3",
+        "node-041": "user3",
+        "node-042": "user4",
+        "node-043": "user4",
+        "node-044": "user4",
+        "node01": "user5",
+        "node02": "user5",
+        "node03": "user5",
+        "node04": "user5",
+        "node05": "user5",
+        "node06": "user5",
+        "node07": "user5",
+        "node08": "user5",
+    }
     slurm_system.parse_sinfo_output(sinfo_output, node_user_map)
-    assert slurm_system.partitions["main"][0].state == SlurmNodeState.IDLE
-    assert slurm_system.partitions["main"][1].state == SlurmNodeState.IDLE
+    inval_nodes = {"node-036"}
+    drain_nodes = {"node-045", "node-046", "node-059", "node-061", "node-062"}
+    resv_nodes = {"node-034", "node-035"}
+    for node in slurm_system.partitions["main"]:
+        if node.name in inval_nodes:
+            assert node.state == SlurmNodeState.INVALID_REGISTRATION
+        elif node.name in drain_nodes:
+            assert node.state == SlurmNodeState.DRAINED
+        elif node.name in resv_nodes:
+            assert node.state == SlurmNodeState.RESERVED
+        else:
+            assert node.state == SlurmNodeState.ALLOCATED
+    for node in slurm_system.partitions["backup"]:
+        assert node.state == SlurmNodeState.IDLE
 
 
 @patch("cloudai.schema.system.SlurmSystem.get_squeue")
 @patch("cloudai.schema.system.SlurmSystem.get_sinfo")
 def test_update_node_states_with_mocked_outputs(mock_get_sinfo, mock_get_squeue, slurm_system):
-    mock_get_squeue.return_value = "nodeA001|root"
-    mock_get_sinfo.return_value = "PARTITION AVAIL TIMELIMIT NODES STATE NODELIST\n" "main up infinite 1 idle nodeA001"
+    mock_get_squeue.return_value = "node-033|user1"
+    mock_get_sinfo.return_value = "PARTITION AVAIL TIMELIMIT NODES STATE NODELIST\n" "main up infinite 1 idle node-033"
 
     slurm_system.update_node_states()
+    for node in slurm_system.partitions["main"]:
+        if node.name == "node-033":
+            assert node.state == SlurmNodeState.IDLE
+            assert node.user == "user1"
+
+    mock_get_squeue.return_value = "node01|root"
+    mock_get_sinfo.return_value = (
+        "PARTITION AVAIL TIMELIMIT NODES STATE NODELIST\n" "backup up infinite 1 allocated node01"
+    )
+
+    slurm_system.update_node_states()
+    for node in slurm_system.partitions["backup"]:
+        if node.name == "node01":
+            assert node.state == SlurmNodeState.ALLOCATED
+            assert node.user == "root"
+
 
-    assert slurm_system.partitions["main"][0].state == SlurmNodeState.IDLE
-    assert slurm_system.partitions["main"][0].user == "root"
+@pytest.mark.parametrize(
+    "node_list,expected_parsed_node_list",
+    [
+        ("node-[048-051]", ["node-048", "node-049", "node-050", "node-051"]),
+        ("node-[055,114]", ["node-055", "node-114"]),
+        ("", []),
+        ("node-001", ["node-001"]),
+        ("node[1-4]", ["node1", "node2", "node3", "node4"]),
+        ("nodeA[001-002],nodeB[001-002]", ["nodeA001", "nodeA002", "nodeB001", "nodeB002"]),
+        ("node1,node2", ["node1", "node2"]),
+        (
+            "node-name[01-03,05-08,10]",
+            [
+                "node-name01",
+                "node-name02",
+                "node-name03",
+                "node-name05",
+                "node-name06",
+                "node-name07",
+                "node-name08",
+                "node-name10",
+            ],
+        ),
+    ],
+)
+def test_parse_node_list(node_list: str, expected_parsed_node_list: List[str], slurm_system):
+    parsed_node_list = slurm_system.parse_node_list(node_list)
+    assert parsed_node_list == expected_parsed_node_list