Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix updating nodes states #31

Merged
merged 4 commits into from
May 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions src/cloudai/parser/system_parser/slurm_system_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,9 @@ def safe_int(value):
raise ValueError("Partition data does not include a 'name' field.")

raw_nodes = partition_data.get("nodes", [])
node_names = set(SlurmSystem.parse_node_list(raw_nodes))
amaslenn marked this conversation as resolved.
Show resolved Hide resolved
node_names = set()
for group in raw_nodes:
node_names.update(set(SlurmSystem.parse_node_list(group)))

if not node_names:
raise ValueError(f"No valid nodes found in partition '{partition_name}'")
Expand Down Expand Up @@ -117,7 +119,9 @@ def safe_int(value):
raise ValueError("Group data does not include a 'name' field.")

raw_nodes = group_data.get("nodes", [])
group_node_names = set(SlurmSystem.parse_node_list(raw_nodes))
group_node_names = set()
for group in raw_nodes:
group_node_names.update(set(SlurmSystem.parse_node_list(group)))

group_nodes = []
for group_node_name in group_node_names:
Expand Down
84 changes: 39 additions & 45 deletions src/cloudai/schema/system/slurm/slurm_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,41 +49,37 @@ class SlurmSystem(System):
"""

@classmethod
def parse_node_list(cls, node_list: str) -> List[str]:
    """
    Expand a Slurm-style node list expression into individual node names.

    The expression may be a single name ("node-001"), a bracketed range
    expression ("node-[001-003,007]"), or several of those separated by
    commas ("nodeA[01-02],nodeB03"). Commas inside brackets separate range
    elements; commas outside brackets separate independent groups.

    Zero padding follows the width of the range start, so "n[08-10]"
    yields n08, n09, n10 while "n[8-10]" yields n8, n9, n10.

    Args:
        node_list (str): Node list expression, possibly including ranges.

    Returns:
        List[str]: Expanded node names in the order they appear. Empty or
            whitespace-only input yields an empty list.

    Raises:
        ValueError: If a range element is not of the form "<int>-<int>".
    """
    # Split on commas that are NOT inside brackets, so that expressions with
    # several bracket groups (as printed by squeue/sinfo) are handled.
    groups = []
    current = []
    depth = 0
    for char in node_list:
        if char == "[":
            depth += 1
        elif char == "]":
            depth = max(depth - 1, 0)
        if char == "," and depth == 0:
            groups.append("".join(current))
            current = []
        else:
            current.append(char)
    groups.append("".join(current))

    nodes = []
    for group in groups:
        group = group.strip()
        if not group:
            continue
        if "[" not in group:
            nodes.append(group)
            continue

        prefix, _, remainder = group.partition("[")
        body, _, suffix = remainder.partition("]")
        for element in body.split(","):
            element = element.strip()
            if not element:
                continue
            if "-" in element:
                start_str, end_str = element.split("-")
                # Pad to the width of the range start to keep leading zeroes
                # without inventing them for ranges such as "8-10".
                width = len(start_str)
                nodes.extend(
                    f"{prefix}{str(i).zfill(width)}{suffix}" for i in range(int(start_str), int(end_str) + 1)
                )
            else:
                nodes.append(f"{prefix}{element}{suffix}")

    return nodes

@classmethod
def format_node_list(cls, node_names: List[str]) -> str:
Expand Down Expand Up @@ -567,11 +563,8 @@ def parse_squeue_output(self, squeue_output: str) -> Dict[str, str]:

node_list_part, user = parts[0], "|".join(parts[1:])
# Handle cases where multiple node groups or ranges are specified
node_groups = node_list_part.split(",")
for node_group in node_groups:
# Process each node or range using parse_node_list
for node in self.parse_node_list([node_group.strip()]):
node_user_map[node] = user.strip()
for node in self.parse_node_list(node_list_part):
node_user_map[node] = user.strip()

return node_user_map

Expand All @@ -589,20 +582,21 @@ def parse_sinfo_output(self, sinfo_output: str, node_user_map: Dict[str, str]) -
parts = line.split()
partition, _, _, _, state, nodelist = parts[:6]
partition = partition.rstrip("*")

node_groups = nodelist.split(",")
for node_group in node_groups:
node_names = self.parse_node_list([node_group.strip()])
state_enum = self.convert_state_to_enum(state)

for node_name in node_names:
for part_name, nodes in self.partitions.items():
if part_name != partition:
continue
for node in nodes:
if node.name == node_name:
node.state = state_enum
node.user = node_user_map.get(node_name, "N/A")
node_names = self.parse_node_list(nodelist)

# Convert state to enum, handling states with suffixes
state_enum = self.convert_state_to_enum(state)

for node_name in node_names:
# Find the partition and node to update the state
for part_name, nodes in self.partitions.items():
if part_name != partition:
continue
for node in nodes:
if node.name == node_name:
node.state = state_enum
node.user = node_user_map.get(node_name, "N/A")
break

def convert_state_to_enum(self, state_str: str) -> SlurmNodeState:
"""
Expand Down Expand Up @@ -701,7 +695,7 @@ def parse_nodes(self, nodes: List[str]) -> List[str]:
else:
# Handle both individual node names and ranges
if self.is_node_in_system(node_spec) or "[" in node_spec:
expanded_nodes = self.parse_node_list([node_spec])
expanded_nodes = self.parse_node_list(node_spec)
parsed_nodes += expanded_nodes
else:
raise ValueError(f"Node '{node_spec}' not found.")
Expand Down
110 changes: 93 additions & 17 deletions tests/test_slurm_system.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import List
from unittest.mock import patch

import pytest
Expand All @@ -7,16 +8,17 @@

@pytest.fixture
def slurm_system():
    """
    Build a SlurmSystem test double with two partitions.

    "main" holds node-033 .. node-064 and "backup" holds node01 .. node08;
    every node starts in UNKNOWN_STATE.
    """
    partitions = {"main": [], "backup": []}
    for i in range(33, 65):
        partitions["main"].append(
            SlurmNode(name=f"node-0{i}", partition="main", state=SlurmNodeState.UNKNOWN_STATE)
        )
    for i in range(1, 9):
        partitions["backup"].append(
            SlurmNode(name=f"node0{i}", partition="backup", state=SlurmNodeState.UNKNOWN_STATE)
        )

    return SlurmSystem(
        name="test_system",
        install_path="/fake/path",
        output_path="/fake/output",
        default_partition="main",
        partitions=partitions,
    )

Expand All @@ -29,7 +31,7 @@ def test_parse_squeue_output(slurm_system):


def test_parse_squeue_output_with_node_ranges_and_root_user(slurm_system):
squeue_output = "nodeA[001-008],nodeB[001-008]|root"
squeue_output = "nodeA[001-008]|root\nnodeB[001-008]|root"
user_map = slurm_system.parse_squeue_output(squeue_output)

expected_nodes = [f"nodeA{str(i).zfill(3)}" for i in range(1, 9)] + [f"nodeB{str(i).zfill(3)}" for i in range(1, 9)]
Expand All @@ -39,24 +41,98 @@ def test_parse_squeue_output_with_node_ranges_and_root_user(slurm_system):


def test_parse_sinfo_output(slurm_system):
    """
    Check that parse_sinfo_output applies per-node state and user.

    The sinfo text covers every node of both fixture partitions, using
    single names, ranges and comma-separated range lists.
    """
    sinfo_output = """PARTITION AVAIL TIMELIMIT NODES STATE NODELIST
main up 3:00:00 1 inval node-036
main up 3:00:00 5 drain node-[045-046,059,061-062]
main up 3:00:00 2 resv node-[034-035]
main up 3:00:00 24 alloc node-[033,037-044,047-058,060,063-064]
backup up 12:00:00 8 idle node[01-08]
"""
    node_user_map = {
        # NOTE(review): the empty key matches no node; presumably it mimics a
        # squeue row without a node list -- confirm it is still needed.
        "": "user1",
        "node-033": "user2",
        "node-037": "user3",
        "node-038": "user3",
        "node-039": "user3",
        "node-040": "user3",
        "node-041": "user3",
        "node-042": "user4",
        "node-043": "user4",
        "node-044": "user4",
        "node01": "user5",
        "node02": "user5",
        "node03": "user5",
        "node04": "user5",
        "node05": "user5",
        "node06": "user5",
        "node07": "user5",
        "node08": "user5",
    }
    slurm_system.parse_sinfo_output(sinfo_output, node_user_map)

    # Nodes of "main" not listed here are in the "alloc" row.
    expected_main_states = {"node-036": SlurmNodeState.INVALID_REGISTRATION}
    for name in ("node-045", "node-046", "node-059", "node-061", "node-062"):
        expected_main_states[name] = SlurmNodeState.DRAINED
    for name in ("node-034", "node-035"):
        expected_main_states[name] = SlurmNodeState.RESERVED

    for node in slurm_system.partitions["main"]:
        assert node.state == expected_main_states.get(node.name, SlurmNodeState.ALLOCATED), node
        # Nodes without a squeue entry are reported with the "N/A" user.
        assert node.user == node_user_map.get(node.name, "N/A"), node

    for node in slurm_system.partitions["backup"]:
        assert node.state == SlurmNodeState.IDLE, node
        assert node.user == node_user_map.get(node.name, "N/A"), node


@patch("cloudai.schema.system.SlurmSystem.get_squeue")
@patch("cloudai.schema.system.SlurmSystem.get_sinfo")
def test_update_node_states_with_mocked_outputs(mock_get_sinfo, mock_get_squeue, slurm_system):
    """
    Check update_node_states against mocked squeue/sinfo output.

    Both scenarios target nodes that exist in the fixture and look them up
    by name, so a node that was not updated (or does not exist) fails the
    test instead of being silently skipped.
    """
    mock_get_squeue.return_value = "node-033|user1"
    mock_get_sinfo.return_value = (
        "PARTITION AVAIL TIMELIMIT NODES STATE NODELIST\n" "main up infinite 1 idle node-033"
    )

    slurm_system.update_node_states()
    main_nodes = {node.name: node for node in slurm_system.partitions["main"]}
    assert main_nodes["node-033"].state == SlurmNodeState.IDLE
    assert main_nodes["node-033"].user == "user1"

    mock_get_squeue.return_value = "node01|root"
    mock_get_sinfo.return_value = (
        "PARTITION AVAIL TIMELIMIT NODES STATE NODELIST\n" "backup up infinite 1 allocated node01"
    )

    slurm_system.update_node_states()
    backup_nodes = {node.name: node for node in slurm_system.partitions["backup"]}
    assert backup_nodes["node01"].state == SlurmNodeState.ALLOCATED
    assert backup_nodes["node01"].user == "root"
@pytest.mark.parametrize(
    "node_list,expected_parsed_node_list",
    [
        # zero-padded range
        ("node-[048-051]", ["node-048", "node-049", "node-050", "node-051"]),
        # comma-separated single entries
        ("node-[055,114]", ["node-055", "node-114"]),
        # empty expression
        ("", []),
        # plain node name without brackets
        ("node-001", ["node-001"]),
        # range without padding
        ("node[1-4]", ["node1", "node2", "node3", "node4"]),
        # mix of ranges and a single entry
        (
            "node-name[01-03,05-08,10]",
            ["node-name01", "node-name02", "node-name03", "node-name05"]
            + ["node-name06", "node-name07", "node-name08", "node-name10"],
        ),
    ],
)
def test_parse_node_list(node_list: str, expected_parsed_node_list: List[str], slurm_system):
    """Each node list expression expands to exactly the expected node names, in order."""
    assert slurm_system.parse_node_list(node_list) == expected_parsed_node_list