@@ -518,8 +518,8 @@ def __init__(self, http_proxy_address: str = None):
518518class _BaseNetworking (Resource ):
519519 """Represent the networking configuration shared by head node and compute node."""
520520
def __init__(self, security_groups: List[str] = None, additional_security_groups: List[str] = None, **kwargs):
    """
    Initialize the security group settings shared by head node and compute node networking.

    :param security_groups: value stored as ``security_groups`` after passing through Resource.init_param
    :param additional_security_groups: value stored as ``additional_security_groups`` after passing
        through Resource.init_param
    :param kwargs: forwarded unchanged to the parent initializer
    """
    super().__init__(**kwargs)
    primary_groups = Resource.init_param(security_groups)
    self.security_groups = primary_groups
    extra_groups = Resource.init_param(additional_security_groups)
    self.additional_security_groups = extra_groups
525525
@@ -548,17 +548,35 @@ def availability_zone(self):
548548
549549
class PlacementGroup(Resource):
    """Represent the placement group for networking."""

    def __init__(self, enabled: bool = None, name: str = None, id: str = None, **kwargs):
        """
        Initialize the placement group settings.

        No default is applied to ``enabled``: it is handed to Resource.init_param as received, so an
        unset value is not forced to False here.

        :param enabled: whether the placement group is enabled (may be left unset)
        :param name: placement group name
        :param id: placement group id (duplicate of name)
        :param kwargs: forwarded unchanged to the parent initializer
        """
        super().__init__(**kwargs)
        self.enabled = Resource.init_param(enabled)
        self.name = Resource.init_param(name)
        self.id = Resource.init_param(id)  # Duplicate of name

    def _register_validators(self):
        """Register the naming validator for this placement group."""
        self._register_validator(PlacementGroupNamingValidator, placement_group=self)

    @property
    def is_enabled_and_unassigned(self) -> bool:
        """
        Check if the PlacementGroup is enabled without a name or id.

        Always returns a real bool: an unset (None) ``enabled`` yields False instead of leaking None
        to the caller.
        """
        return bool(self.enabled) and not (self.id or self.name)

    @property
    def assignment(self) -> str:
        """
        Get the placement group name or id, preferring the name if it exists.

        Returns a falsy value (None when both are unset) if neither a name nor an id is present.
        """
        return self.name or self.id
571+
572+
class SlurmComputeResourceNetworking(Resource):
    """Represent the networking configuration for the compute resource."""

    def __init__(self, placement_group: PlacementGroup = None, **kwargs):
        """
        Initialize the compute resource networking.

        :param placement_group: placement group to use; when falsy, a PlacementGroup created with
            ``implied=True`` is stored instead
        :param kwargs: forwarded unchanged to the parent initializer
        """
        super().__init__(**kwargs)
        if not placement_group:
            # Nothing usable was supplied: fall back to an implied placement group.
            placement_group = PlacementGroup(implied=True)
        self.placement_group = placement_group
579+
562580
563581class _QueueNetworking (_BaseNetworking ):
564582 """Represent the networking configuration for the Queue."""
@@ -574,7 +592,7 @@ class SlurmQueueNetworking(_QueueNetworking):
574592
def __init__(self, placement_group: PlacementGroup = None, proxy: Proxy = None, **kwargs):
    """
    Initialize the queue networking.

    :param placement_group: placement group to use; when falsy, a PlacementGroup created with
        ``implied=True`` is stored instead
    :param proxy: proxy setting, stored as given
    :param kwargs: forwarded unchanged to the parent initializer
    """
    super().__init__(**kwargs)
    if placement_group:
        self.placement_group = placement_group
    else:
        # Nothing usable was supplied: fall back to an implied placement group.
        self.placement_group = PlacementGroup(implied=True)
    self.proxy = proxy
579597
580598
@@ -1623,6 +1641,7 @@ def __init__(
16231641 disable_simultaneous_multithreading : bool = None ,
16241642 schedulable_memory : int = None ,
16251643 capacity_reservation_target : CapacityReservationTarget = None ,
1644+ networking : SlurmComputeResourceNetworking = None ,
16261645 ** kwargs ,
16271646 ):
16281647 super ().__init__ (** kwargs )
@@ -1637,6 +1656,7 @@ def __init__(
16371656 self .capacity_reservation_target = capacity_reservation_target
16381657 self ._instance_types_with_instance_storage = []
16391658 self ._instance_type_info_map = {}
1659+ self .networking = networking or SlurmComputeResourceNetworking (implied = True )
16401660
16411661 @staticmethod
16421662 def fetch_instance_type_info (instance_type ) -> InstanceTypeInfo :
@@ -1792,11 +1812,25 @@ def disable_simultaneous_multithreading_manually(self) -> bool:
17921812 return self .disable_simultaneous_multithreading and self .instance_type_info .default_threads_per_core () > 1
17931813
17941814
class SchedulerPluginComputeResource(SlurmComputeResource):
    """Represent the Scheduler Plugin Compute Resource."""

    def __init__(self, custom_settings: Dict = None, **kwargs):
        """
        Initialize the compute resource.

        :param custom_settings: stored as given on ``custom_settings``
        :param kwargs: forwarded unchanged to the parent initializer
        """
        super().__init__(**kwargs)
        self.custom_settings = custom_settings
1825+
1826+
17951827class _CommonQueue (BaseQueue ):
17961828 """Represent the Common Queue resource between Slurm and Scheduler Plugin."""
17971829
17981830 def __init__ (
17991831 self ,
1832+ compute_resources : List [Union [_BaseSlurmComputeResource , SchedulerPluginComputeResource ]],
1833+ networking : Union [SlurmQueueNetworking , SchedulerPluginQueueNetworking ],
18001834 compute_settings : ComputeSettings = None ,
18011835 custom_actions : CustomActions = None ,
18021836 iam : Iam = None ,
@@ -1810,6 +1844,8 @@ def __init__(
18101844 self .iam = iam or Iam (implied = True )
18111845 self .image = image
18121846 self .capacity_reservation_target = capacity_reservation_target
1847+ self .compute_resources = compute_resources
1848+ self .networking = networking
18131849
18141850 @property
18151851 def instance_role (self ):
@@ -1829,6 +1865,43 @@ def queue_ami(self):
18291865 else :
18301866 return None
18311867
def get_managed_placement_group_keys(self) -> List[str]:
    """
    Return the placement group keys of the compute resources that need a managed placement group.

    The decision is delegated to get_placement_group_key_for_compute_resource so that both methods
    always agree on which compute resources are managed and on the key used for them (including the
    case of a compute resource placement group that is present but sets neither enabled nor name/id,
    which inherits the queue level setting).

    :return: one key per compute resource reported as managed, in compute resource order
    """
    managed_placement_group_keys = []
    for resource in self.compute_resources:
        placement_group_key, managed = self.get_placement_group_key_for_compute_resource(resource)
        if managed:
            managed_placement_group_keys.append(placement_group_key)
    return managed_placement_group_keys
1879+
def get_placement_group_key_for_compute_resource(
    self, compute_resource: Union[_BaseSlurmComputeResource, SchedulerPluginComputeResource]
) -> (str, bool):
    """
    Resolve the placement group key for a compute resource and whether that group is managed.

    Compute resource level settings win over queue level settings:
    a name/id on the compute resource is used as is (not managed); an enabled compute resource group
    gets the "<queue name>-<compute resource name>" key (managed); an explicit False disables the
    group; otherwise the same name/id and enabled checks are applied to the queue level group.

    :param compute_resource: compute resource whose placement group is being resolved
    :return: (key, managed) tuple; (None, False) when disabled at compute resource level and
        (None, None) when no rule matches
    """
    # prefer compute level groups over queue level groups
    resource_pg = compute_resource.networking.placement_group
    if resource_pg.assignment:
        return resource_pg.assignment, False
    if resource_pg.enabled:
        return f"{self.name}-{compute_resource.name}", True
    if resource_pg.enabled is False:
        # Explicitly switched off on the compute resource: the queue setting is not consulted.
        return None, False

    queue_pg = self.networking.placement_group
    if queue_pg.assignment:
        return queue_pg.assignment, False
    if queue_pg.enabled:
        return f"{self.name}-{compute_resource.name}", True
    return None, None
1897+
def is_placement_group_disabled_for_compute_resource(self, compute_resource_pg_enabled: bool) -> bool:
    """
    Tell whether the placement group is disabled for a compute resource.

    It is disabled when the compute resource value is explicitly False, or when the compute resource
    value is None and the queue level placement group is explicitly False.

    :param compute_resource_pg_enabled: the ``enabled`` value of the compute resource placement group
    :return: True if the placement group is disabled for the compute resource, False otherwise
    """
    if compute_resource_pg_enabled is False:
        return True
    queue_level_disabled = self.networking.placement_group.enabled is False
    return queue_level_disabled and compute_resource_pg_enabled is None
1904+
18321905
18331906class AllocationStrategy (Enum ):
18341907 """Define supported allocation strategies."""
@@ -1842,15 +1915,13 @@ class SlurmQueue(_CommonQueue):
18421915
18431916 def __init__ (
18441917 self ,
1845- compute_resources : List [_BaseSlurmComputeResource ],
1846- networking : SlurmQueueNetworking ,
18471918 allocation_strategy : str = None ,
18481919 ** kwargs ,
18491920 ):
18501921 super ().__init__ (** kwargs )
1851- self . compute_resources = compute_resources
1852- self . networking = networking
1853- if any ( isinstance ( compute_resource , SlurmFlexibleComputeResource ) for compute_resource in compute_resources ):
1922+ if any (
1923+ isinstance ( compute_resource , SlurmFlexibleComputeResource ) for compute_resource in self . compute_resources
1924+ ):
18541925 self .allocation_strategy = (
18551926 AllocationStrategy [to_snake_case (allocation_strategy ).upper ()]
18561927 if allocation_strategy
@@ -1896,7 +1967,10 @@ def _register_validators(self):
18961967 self ._register_validator (
18971968 EfaPlacementGroupValidator ,
18981969 efa_enabled = compute_resource .efa .enabled ,
1899- placement_group = self .networking .placement_group ,
1970+ placement_group_key = self .get_placement_group_key_for_compute_resource (compute_resource )[0 ],
1971+ placement_group_disabled = self .is_placement_group_disabled_for_compute_resource (
1972+ compute_resource .networking .placement_group .enabled
1973+ ),
19001974 )
19011975 for instance_type in compute_resource .instance_types :
19021976 self ._register_validator (
@@ -1967,31 +2041,15 @@ def _register_validators(self):
19672041 )
19682042
19692043
1970- class SchedulerPluginComputeResource (SlurmComputeResource ):
1971- """Represent the Scheduler Plugin Compute Resource."""
1972-
1973- def __init__ (
1974- self ,
1975- custom_settings : Dict = None ,
1976- ** kwargs ,
1977- ):
1978- super ().__init__ (** kwargs )
1979- self .custom_settings = custom_settings
1980-
1981-
19822044class SchedulerPluginQueue (_CommonQueue ):
19832045 """Represent the Scheduler Plugin queue."""
19842046
def __init__(self, custom_settings: Dict = None, **kwargs):
    """
    Initialize the Scheduler Plugin queue.

    :param custom_settings: stored as given on ``custom_settings``
    :param kwargs: forwarded unchanged to the parent initializer
    """
    super().__init__(**kwargs)
    self.custom_settings = custom_settings
19962054
19972055 def _register_validators (self ):
@@ -2014,7 +2072,10 @@ def _register_validators(self):
20142072 self ._register_validator (
20152073 EfaPlacementGroupValidator ,
20162074 efa_enabled = compute_resource .efa .enabled ,
2017- placement_group = self .networking .placement_group ,
2075+ placement_group_key = self .get_placement_group_key_for_compute_resource (compute_resource )[0 ],
2076+ placement_group_disabled = self .is_placement_group_disabled_for_compute_resource (
2077+ compute_resource .networking .placement_group .enabled
2078+ ),
20182079 )
20192080
20202081 @property
0 commit comments