From 6d680f7f28fb6535e4e958caab9eb210abc49697 Mon Sep 17 00:00:00 2001 From: jeffnvidia Date: Tue, 6 Aug 2024 13:46:30 +0300 Subject: [PATCH] simplify the code --- src/cloudai/systems/slurm/slurm_system.py | 97 ++++++++--------------- 1 file changed, 33 insertions(+), 64 deletions(-) diff --git a/src/cloudai/systems/slurm/slurm_system.py b/src/cloudai/systems/slurm/slurm_system.py index 27aac5b9..e1cfe7fb 100644 --- a/src/cloudai/systems/slurm/slurm_system.py +++ b/src/cloudai/systems/slurm/slurm_system.py @@ -335,37 +335,20 @@ def get_group_node_names(self, partition_name: str, group_name: str) -> List[str """ return [node.name for node in self.get_group_nodes(partition_name, group_name)] - def get_available_nodes_from_group_with_reservation( - self, partition_name: str, group_name: str, number_of_nodes: int - ) -> List[SlurmNode]: + def _get_reserved_nodes(self, partition_name: str, group_name: str) -> Dict[SlurmNodeState, List[SlurmNode]]: """ - Retrieve a specific number of potentially available nodes from a group within a partition. - - Prioritizes nodes by their current state, preferring idle nodes first, then completing nodes, and finally - allocated nodes, while excluding nodes that are down and allocated nodes to the current user. - If a reservation was queried, then cloudAI will take from the reserved nodes according to the reservation name. + Return the reserved nodes corresponding to the given reservation name. Args: partition_name (str): The name of the partition. group_name (str): The name of the group. - number_of_nodes (int): The number of nodes to retrieve. Returns: - List[SlurmNode]: Objects that are potentially available for use. - - Raises: - ValueError: If the partition or group is not found, or if the requested number of nodes exceeds the - available nodes. + Dict[str, str]: Names of nodes within the specified group and partition and reservation. """ - if partition_name not in self.groups: - raise ValueError(f"Partition '{partition_name}' not found.") - if group_name not in self.groups[partition_name]: - raise ValueError(f"Group '{group_name}' not found in partition '{partition_name}'.") - - self.update_node_states() - - # Group nodes by their states reservation_key = "--reservation " + if not self.extra_srun_args: + raise ValueError("extra_srun_args shouldn't be empty") reservation_name = self.extra_srun_args.split(reservation_key, 1)[1].split(" ", 1)[0] reservation_output = self.get_reservation(reservation_name) reserved_nodes = self.parse_reservation_output(reservation_output, reservation_name) @@ -375,30 +358,30 @@ def get_available_nodes_from_group_with_reservation( for node in self.groups[partition_name][group_name]: if node.state in grouped_nodes and node.name in reserved_nodes: grouped_nodes[node.state].append(node) - - # Allocate nodes based on priority: idle, then completing, then allocated - allocated_nodes = [] - for state in grouped_nodes: - while grouped_nodes[state] and len(allocated_nodes) < number_of_nodes: - allocated_nodes.append(grouped_nodes[state].pop(0)) - if len(allocated_nodes) < number_of_nodes: - raise ValueError( - "Requested number of nodes ({}) exceeds the number of " "available nodes in group '{}'.".format( - number_of_nodes, group_name - ) - ) + return grouped_nodes - # Log allocation details - logging.info( - "Allocated nodes from group '{}' in partition '{}': {}".format( - group_name, - partition_name, - [node.name for node in allocated_nodes], - ) - ) + def _get_available_nodes(self, partition_name: str, group_name: str): + """ + Return the available nodes sorted into idle and completing. - return allocated_nodes + Args: + partition_name (str): The name of the partition. + group_name (str): The name of the group. + + Returns: + Dict[str, str]: Names of nodes within the specified group and partition and reservation. + """ + grouped_nodes = { + SlurmNodeState.IDLE: [], + SlurmNodeState.COMPLETING: [], + } + + for node in self.groups[partition_name][group_name]: + if node.state in grouped_nodes: + grouped_nodes[node.state].append(node) + + return grouped_nodes def get_available_nodes_from_group( self, partition_name: str, group_name: str, number_of_nodes: int @@ -429,19 +412,10 @@ def get_available_nodes_from_group( self.update_node_states() - grouped_nodes = { - SlurmNodeState.IDLE: [], - SlurmNodeState.COMPLETING: [], - SlurmNodeState.ALLOCATED: [], - } - - for node in self.groups[partition_name][group_name]: - if node.state in grouped_nodes: - # Exclude nodes allocated to the current user - if node.state == SlurmNodeState.ALLOCATED and node.user == current_user: - continue - if node.state in grouped_nodes: - grouped_nodes[node.state].append(node) + if self.extra_srun_args and "reservation" in self.extra_srun_args: + grouped_nodes = self._get_reserved_nodes(partition_name, group_name) + else: + grouped_nodes = self._get_available_nodes(partition_name, group_name) # Allocate nodes based on priority: idle, then completing, then allocated allocated_nodes = [] @@ -714,8 +688,8 @@ def parse_reservation_output(self, reservation_output: str, reservation_name: st if reservation_name in reservation: nodes = reservation.split("Nodes=")[1].split(" ")[0] node_list = self.parse_node_list(nodes) - - return node_list + return node_list + raise ValueError("wrong reservation specified \n. Reservation should be in the form \"--reservation reservation_name\"") def convert_state_to_enum(self, state_str: str) -> SlurmNodeState: """ @@ -803,12 +777,7 @@ def parse_nodes(self, nodes: List[str]) -> List[str]: raise ValueError("Format should be partition:group:num_nodes") partition_name, group_name, num_nodes_str = parts num_nodes = int(num_nodes_str) - if self.extra_srun_args and "reservation" in self.extra_srun_args: - group_nodes = self.get_available_nodes_from_group_with_reservation( - partition_name, group_name, num_nodes - ) - else: - group_nodes = self.get_available_nodes_from_group(partition_name, group_name, num_nodes) + group_nodes = self.get_available_nodes_from_group(partition_name, group_name, num_nodes) parsed_nodes += [node.name for node in group_nodes] else: # Handle both individual node names and ranges