Skip to content

Commit 0dfd219

Browse files
temporary commit
Signed-off-by: Hanwen <[email protected]>
1 parent b32710d commit 0dfd219

File tree

10 files changed

+290
-123
lines changed

10 files changed

+290
-123
lines changed

cli/src/pcluster/aws/aws_resources.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ def __init__(self, stack_data: dict):
3737
self._params = self._stack_data.get("Parameters", [])
3838
self.tags = self._stack_data.get("Tags", [])
3939
self.outputs = self._stack_data.get("Outputs", [])
40+
self.__resources = None
4041

4142
@property
4243
def id(self):
@@ -48,6 +49,15 @@ def name(self):
4849
"""Return the name of the stack."""
4950
return self._stack_data.get("StackName")
5051

52+
@property
def resources(self):
    """
    Return the resources of the stack.

    The value is whatever ``AWSApi.instance().cfn.describe_stack_resources`` returns for this stack name.
    It is fetched lazily on first access and then cached on the instance.
    """
    # Compare against None instead of relying on truthiness: an empty result is still a valid cached
    # answer and must not trigger a new CloudFormation call on every access.
    if self.__resources is None:
        # Kept as a function-scope import, as in the original code
        # (presumably to avoid a circular import with aws_api -- TODO confirm).
        from pcluster.aws.aws_api import AWSApi

        self.__resources = AWSApi.instance().cfn.describe_stack_resources(self.name)
    return self.__resources
60+
5161
@property
5262
def status(self):
5363
"""Return the status of the stack."""
@@ -90,6 +100,10 @@ def _get_param(self, key_name):
90100
param_value = next((par["ParameterValue"] for par in self._params if par["ParameterKey"] == key_name), None)
91101
return None if param_value is None else param_value.strip()
92102

103+
def get_resource(self, resource_logical_id: str):
    """
    Return the resource information stored under the given logical ID.

    The lookup is done with ``.get`` on ``self.resources``, so None is returned when the key is missing.
    """
    stack_resources = self.resources
    return stack_resources.get(resource_logical_id)
106+
93107

94108
class InstanceInfo:
95109
"""Object to store Instance information, initialized with a describe_instances call."""

cli/src/pcluster/aws/cfn.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,12 @@ def describe_stack_resource(self, stack_name: str, logic_resource_id: str):
147147
function_name="describe_stack_resource", message=f"No resource {logic_resource_id} found."
148148
)
149149

150+
@AWSExceptionHandler.handle_client_exception
def describe_stack_resources(self, stack_name: str):
    """
    Get stack resources information, indexed by logical resource ID.

    :param stack_name: value passed as ``StackName`` to the CloudFormation client
    :return: dict mapping each resource's ``LogicalResourceId`` to its full description;
        empty when the response carries no ``StackResources``
    """
    # NOTE(review): the DescribeStackResources API documents a cap of 100 returned resources and
    # recommends ListStackResources for larger stacks -- confirm cluster stacks stay below that.
    response = self._client.describe_stack_resources(StackName=stack_name)
    # Fall back to an empty list so a missing/None "StackResources" yields {} instead of a TypeError.
    stack_resources = response.get("StackResources") or []
    # Build dictionary for better query.
    return {resource["LogicalResourceId"]: resource for resource in stack_resources}
155+
150156
@AWSExceptionHandler.handle_client_exception
151157
def get_imagebuilder_stacks(self, next_token=None):
152158
"""List existing imagebuilder stacks."""

cli/src/pcluster/aws/ec2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ def get_subnet_vpc(self, subnet_id):
136136
@Cache.cached
def get_subnet_cidr(self, subnet_id):
    """
    Return cidr block of the given subnet.

    Raises AWSClientError when ``describe_subnets`` returns nothing for the subnet ID.
    """
    matching_subnets = self.describe_subnets([subnet_id])
    # Guard clause: nothing came back for this ID, so there is no CIDR to report.
    if not matching_subnets:
        raise AWSClientError(function_name="describe_subnets", message=f"Subnet {subnet_id} not found")
    return matching_subnets[0].get("CidrBlock")

cli/src/pcluster/config/cluster_config.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1261,6 +1261,8 @@ def __init__(
12611261
self._official_ami = None
12621262
self.imds = imds or TopLevelImds(implied="v1.0")
12631263
self.deployment_settings = deployment_settings
1264+
self.managed_head_node_security_group = None
1265+
self.managed_compute_security_group = None
12641266

12651267
def _register_validators(self, context: ValidatorContext = None): # noqa: D102 #pylint: disable=unused-argument
12661268
self._register_validator(RegionValidator, region=self.region)
@@ -1387,15 +1389,15 @@ def _register_storage_validators(self):
13871389
EfsIdValidator,
13881390
efs_id=storage.file_system_id,
13891391
avail_zones_mapping=self.availability_zones_subnets_mapping,
1390-
are_all_security_groups_customized=self.are_all_security_groups_customized,
1392+
security_groups_by_nodes=self.security_groups_by_nodes,
13911393
)
13921394
else:
13931395
new_storage_count["efs"] += 1
13941396
self._register_validator(
13951397
ExistingFsxNetworkingValidator,
13961398
file_system_ids=list(existing_fsx),
1397-
head_node_subnet_id=self.head_node.networking.subnet_id,
1398-
are_all_security_groups_customized=self.are_all_security_groups_customized,
1399+
subnet_ids=[self.head_node.networking.subnet_id] + self.compute_subnet_ids,
1400+
security_groups_by_nodes=self.security_groups_by_nodes,
13991401
)
14001402

14011403
self._validate_max_storage_count(ebs_count, existing_storage_count, new_storage_count)
@@ -1624,17 +1626,29 @@ def is_dcv_enabled(self):
16241626
return self.head_node.dcv and self.head_node.dcv.enabled
16251627

16261628
@property
def security_groups_by_nodes(self):
    """
    Return the distinct security group combinations used by the nodes of the cluster.

    The result is a set of frozensets: one for the head node plus one for every queue whose networking
    is a ``_QueueNetworking``. Each frozenset holds the configured ``security_groups`` -- or, when none
    are configured, the managed security group of that node type -- together with any
    ``additional_security_groups``.

    The managed security group attributes are None until a caller assigns them, so a frozenset may
    contain None. Because the values end up in sets, whatever is assigned must be hashable.
    """

    def _effective_groups(networking, managed_security_group):
        # Configured security groups take the place of the managed one; additional groups are always
        # added on top. The same rule is applied to head node and queues so they cannot diverge.
        groups = set(networking.security_groups or [managed_security_group])
        groups.update(networking.additional_security_groups or [])
        return frozenset(groups)

    # Head node: previously the managed group was dropped as soon as additional groups were set,
    # unlike the queue handling below; it is now only replaced by explicit security_groups.
    security_groups_for_all_nodes = {
        _effective_groups(self.head_node.networking, self.managed_head_node_security_group)
    }
    for queue in self.scheduling.queues:
        queue_networking = queue.networking
        # Queues with another networking type are left out, as in the original logic.
        if isinstance(queue_networking, _QueueNetworking):
            security_groups_for_all_nodes.add(
                _effective_groups(queue_networking, self.managed_compute_security_group)
            )
    return security_groups_for_all_nodes
16381652

16391653
@property
16401654
def extra_chef_attributes(self):

cli/src/pcluster/models/cluster.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -440,6 +440,9 @@ def _validate_and_parse_config(
440440
Cluster._load_additional_instance_type_data(cluster_config_dict)
441441
config = self._load_config(cluster_config_dict)
442442
config.official_ami = self.__official_ami
443+
if context.during_update:
444+
config.managed_head_node_security_group = self.stack.get_resource("HeadNodeSecurityGroup")
445+
config.managed_compute_security_group = self.stack.get_resource("ComputeSecurityGroup")
443446

444447
validation_failures = config.validate(validator_suppressors, context)
445448
if any(f.level.value >= FailureLevel(validation_failure_level).value for f in validation_failures):
@@ -851,7 +854,7 @@ def validate_update_request(
851854
validator_suppressors=validator_suppressors,
852855
validation_failure_level=validation_failure_level,
853856
config_text=target_source_config,
854-
context=ValidatorContext(head_node_instance_id=self.head_node_instance.id),
857+
context=ValidatorContext(head_node_instance_id=self.head_node_instance.id, during_update=True),
855858
)
856859
changes = self._validate_patch(force, target_config)
857860

0 commit comments

Comments
 (0)