Skip to content

Commit

Permalink
fix updating nodes states
Browse files Browse the repository at this point in the history
  • Loading branch information
jeffnvidia committed May 27, 2024
1 parent 1e11c49 commit 5dcf3d6
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 45 deletions.
8 changes: 6 additions & 2 deletions src/cloudai/parser/system_parser/slurm_system_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,9 @@ def safe_int(value):
raise ValueError("Partition data does not include a 'name' field.")

raw_nodes = partition_data.get("nodes", [])
node_names = set(SlurmSystem.parse_node_list(raw_nodes))
node_names = set()
for group in raw_nodes:
node_names.update(set(SlurmSystem.parse_node_list(group)))

if not node_names:
raise ValueError(f"No valid nodes found in partition '{partition_name}'")
Expand Down Expand Up @@ -117,7 +119,9 @@ def safe_int(value):
raise ValueError("Group data does not include a 'name' field.")

raw_nodes = group_data.get("nodes", [])
group_node_names = set(SlurmSystem.parse_node_list(raw_nodes))
group_node_names = set()
for group in raw_nodes:
group_node_names.update(set(SlurmSystem.parse_node_list(group)))

group_nodes = []
for group_node_name in group_node_names:
Expand Down
79 changes: 36 additions & 43 deletions src/cloudai/schema/system/slurm/slurm_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class SlurmSystem(System):
"""

@classmethod
def parse_node_list(cls, node_list: List[str]) -> List[str]:
def parse_node_list(cls, node_list: str) -> List[str]:
"""
Expand a list of node names (with ranges) into a flat list of individual node names, keeping leading zeroes.
Expand All @@ -60,30 +60,23 @@ def parse_node_list(cls, node_list: List[str]) -> List[str]:
List[str]: A flat list of expanded node names with preserved
zeroes.
"""
expanded_nodes = []
for node in node_list:
if "[" in node and "]" in node:
prefix, ranges = node.split("[")
ranges = ranges.strip("]")
range_elements = ranges.split(",")
for r in range_elements:
if "-" in r:
start_str, end_str = r.split("-")
else:
# For single nodes, treat the node itself as both start and end.
start_str = end_str = r

start, end = int(start_str), int(end_str)
max_length = max(len(start_str), len(end_str))

if "-" in r:
expanded_nodes.extend([f"{prefix}{str(i).zfill(max_length)}" for i in range(start, end + 1)])
else:
# For single nodes, append directly with appropriate padding.
expanded_nodes.append(f"{prefix}{start_str.zfill(max_length)}")
nodes = []
if "[" not in node_list:
return [node_list]
header, node_number = node_list.split("[")
node_number = node_number.replace("]", "")
ranges = node_number.split(",")
for r in ranges:
if "-" in r:
start_node, end_node = r.split("-")
number_of_digits = len(end_node)
nodes.extend(
[f"{header}{str(i).zfill(number_of_digits)}" for i in range(int(start_node), int(end_node) + 1)]
)
else:
expanded_nodes.append(node)
return expanded_nodes
nodes.append(f"{header}{r}")

return nodes

@classmethod
def format_node_list(cls, node_names: List[str]) -> str:
Expand Down Expand Up @@ -567,10 +560,8 @@ def parse_squeue_output(self, squeue_output: str) -> Dict[str, str]:

node_list_part, user = parts[0], "|".join(parts[1:])
# Handle cases where multiple node groups or ranges are specified
node_groups = node_list_part.split(",")
for node_group in node_groups:
# Process each node or range using parse_node_list
for node in self.parse_node_list([node_group.strip()]):
if node_list_part:
for node in self.parse_node_list(node_list_part.strip()):
node_user_map[node] = user.strip()

return node_user_map
Expand All @@ -589,20 +580,22 @@ def parse_sinfo_output(self, sinfo_output: str, node_user_map: Dict[str, str]) -
parts = line.split()
partition, _, _, _, state, nodelist = parts[:6]
partition = partition.rstrip("*")

node_groups = nodelist.split(",")
for node_group in node_groups:
node_names = self.parse_node_list([node_group.strip()])
state_enum = self.convert_state_to_enum(state)

for node_name in node_names:
for part_name, nodes in self.partitions.items():
if part_name != partition:
continue
for node in nodes:
if node.name == node_name:
node.state = state_enum
node.user = node_user_map.get(node_name, "N/A")
node_names = self.parse_node_list(nodelist)

# Convert state to enum, handling states with suffixes
state_enum = self.convert_state_to_enum(state)

nodelist.split(",")
for node_name in node_names:
# Find the partition and node to update the state
for part_name, nodes in self.partitions.items():
if part_name != partition:
continue
for node in nodes:
if node.name == node_name:
node.state = state_enum
node.user = node_user_map.get(node_name, "N/A")
break

def convert_state_to_enum(self, state_str: str) -> SlurmNodeState:
"""
Expand Down Expand Up @@ -701,7 +694,7 @@ def parse_nodes(self, nodes: List[str]) -> List[str]:
else:
# Handle both individual node names and ranges
if self.is_node_in_system(node_spec) or "[" in node_spec:
expanded_nodes = self.parse_node_list([node_spec])
expanded_nodes = self.parse_node_list(node_spec)
parsed_nodes += expanded_nodes
else:
raise ValueError(f"Node '{node_spec}' not found.")
Expand Down

0 comments on commit 5dcf3d6

Please sign in to comment.