Skip to content

Commit dff51f7

Browse files
temporary commit
Signed-off-by: Hanwen <[email protected]>
1 parent b32710d commit dff51f7

File tree

10 files changed

+276
-123
lines changed

10 files changed

+276
-123
lines changed

cli/src/pcluster/aws/aws_resources.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ def __init__(self, stack_data: dict):
3737
self._params = self._stack_data.get("Parameters", [])
3838
self.tags = self._stack_data.get("Tags", [])
3939
self.outputs = self._stack_data.get("Outputs", [])
40+
self.__resources = None
4041

4142
@property
4243
def id(self):
@@ -48,6 +49,15 @@ def name(self):
4849
"""Return the name of the stack."""
4950
return self._stack_data.get("StackName")
5051

52+
@property
53+
def resources(self):
54+
"""Return the resources of the stack."""
55+
if not self.__resources:
56+
from pcluster.aws.aws_api import AWSApi
57+
58+
self.__resources = AWSApi.instance().cfn.describe_stack_resources(self.name)
59+
return self.__resources
60+
5161
@property
5262
def status(self):
5363
"""Return the status of the stack."""
@@ -90,6 +100,10 @@ def _get_param(self, key_name):
90100
param_value = next((par["ParameterValue"] for par in self._params if par["ParameterKey"] == key_name), None)
91101
return None if param_value is None else param_value.strip()
92102

103+
def get_resource(self, resource_logical_id: str):
104+
"""Return the resource information."""
105+
return self.resources.get(resource_logical_id)
106+
93107

94108
class InstanceInfo:
95109
"""Object to store Instance information, initialized with a describe_instances call."""

cli/src/pcluster/aws/cfn.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,12 @@ def describe_stack_resource(self, stack_name: str, logic_resource_id: str):
147147
function_name="describe_stack_resource", message=f"No resource {logic_resource_id} found."
148148
)
149149

150+
@AWSExceptionHandler.handle_client_exception
151+
def describe_stack_resources(self, stack_name: str):
152+
"""Get stack resources information."""
153+
response = self._client.describe_stack_resources(StackName=stack_name).get("StackResources")
154+
return {resource["LogicalResourceId"]: resource for resource in response} # Build dictionary for better query.
155+
150156
@AWSExceptionHandler.handle_client_exception
151157
def get_imagebuilder_stacks(self, next_token=None):
152158
"""List existing imagebuilder stacks."""

cli/src/pcluster/aws/ec2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ def get_subnet_vpc(self, subnet_id):
136136
@Cache.cached
137137
def get_subnet_cidr(self, subnet_id):
138138
"""Return cidr block of the given subnet."""
139-
subnets = self._client.describe_subnets(SubnetIds=[subnet_id]).get("Subnets")
139+
subnets = self.describe_subnets([subnet_id])
140140
if subnets:
141141
return subnets[0].get("CidrBlock")
142142
raise AWSClientError(function_name="describe_subnets", message=f"Subnet {subnet_id} not found")

cli/src/pcluster/config/cluster_config.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1261,6 +1261,8 @@ def __init__(
12611261
self._official_ami = None
12621262
self.imds = imds or TopLevelImds(implied="v1.0")
12631263
self.deployment_settings = deployment_settings
1264+
self.managed_head_node_security_group = None
1265+
self.managed_compute_security_group = None
12641266

12651267
def _register_validators(self, context: ValidatorContext = None): # noqa: D102 #pylint: disable=unused-argument
12661268
self._register_validator(RegionValidator, region=self.region)
@@ -1387,15 +1389,15 @@ def _register_storage_validators(self):
13871389
EfsIdValidator,
13881390
efs_id=storage.file_system_id,
13891391
avail_zones_mapping=self.availability_zones_subnets_mapping,
1390-
are_all_security_groups_customized=self.are_all_security_groups_customized,
1392+
all_security_groups=self.all_security_groups,
13911393
)
13921394
else:
13931395
new_storage_count["efs"] += 1
13941396
self._register_validator(
13951397
ExistingFsxNetworkingValidator,
13961398
file_system_ids=list(existing_fsx),
1397-
head_node_subnet_id=self.head_node.networking.subnet_id,
1398-
are_all_security_groups_customized=self.are_all_security_groups_customized,
1399+
subnet_ids=[self.head_node.networking.subnet_id] + self.compute_subnet_ids,
1400+
all_security_groups=self.all_security_groups,
13991401
)
14001402

14011403
self._validate_max_storage_count(ebs_count, existing_storage_count, new_storage_count)
@@ -1624,17 +1626,29 @@ def is_dcv_enabled(self):
16241626
return self.head_node.dcv and self.head_node.dcv.enabled
16251627

16261628
@property
1627-
def are_all_security_groups_customized(self):
1629+
def all_security_groups(self):
16281630
"""Return True if all head node and queues have (additional) security groups specified."""
16291631
head_node_networking = self.head_node.networking
1632+
security_groups_for_head_node = set()
1633+
if head_node_networking.security_groups:
1634+
security_groups_for_head_node.update(set(head_node_networking.security_groups))
1635+
if head_node_networking.additional_security_groups:
1636+
security_groups_for_head_node.update(set(head_node_networking.additional_security_groups))
16301637
if not (head_node_networking.security_groups or head_node_networking.additional_security_groups):
1631-
return False
1638+
security_groups_for_head_node.add(self.managed_head_node_security_group)
1639+
security_groups_for_all_nodes = {frozenset(security_groups_for_head_node)}
16321640
for queue in self.scheduling.queues:
16331641
queue_networking = queue.networking
16341642
if isinstance(queue_networking, _QueueNetworking):
1635-
if not (queue_networking.security_groups or queue_networking.additional_security_groups):
1636-
return False
1637-
return True
1643+
security_groups_for_compute_node = set()
1644+
if queue_networking.security_groups:
1645+
security_groups_for_compute_node.update(set(queue_networking.security_groups))
1646+
else:
1647+
security_groups_for_compute_node.add(self.managed_compute_security_group)
1648+
if queue_networking.additional_security_groups:
1649+
security_groups_for_compute_node.update(set(queue_networking.additional_security_groups))
1650+
security_groups_for_all_nodes.add(frozenset(security_groups_for_compute_node))
1651+
return security_groups_for_all_nodes
16381652

16391653
@property
16401654
def extra_chef_attributes(self):

cli/src/pcluster/models/cluster.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -440,6 +440,9 @@ def _validate_and_parse_config(
440440
Cluster._load_additional_instance_type_data(cluster_config_dict)
441441
config = self._load_config(cluster_config_dict)
442442
config.official_ami = self.__official_ami
443+
if context.during_update:
444+
config.managed_head_node_security_group = self.stack.get_resource("HeadNodeSecurityGroup")
445+
config.managed_compute_security_group = self.stack.get_resource("ComputeSecurityGroup")
443446

444447
validation_failures = config.validate(validator_suppressors, context)
445448
if any(f.level.value >= FailureLevel(validation_failure_level).value for f in validation_failures):
@@ -851,7 +854,7 @@ def validate_update_request(
851854
validator_suppressors=validator_suppressors,
852855
validation_failure_level=validation_failure_level,
853856
config_text=target_source_config,
854-
context=ValidatorContext(head_node_instance_id=self.head_node_instance.id),
857+
context=ValidatorContext(head_node_instance_id=self.head_node_instance.id, during_update=True),
855858
)
856859
changes = self._validate_patch(force, target_config)
857860

cli/src/pcluster/validators/cluster_validators.py

Lines changed: 64 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import re
1313
from collections import defaultdict
1414
from enum import Enum
15-
from ipaddress import ip_network
15+
from ipaddress import collapse_addresses, ip_network
1616
from itertools import combinations, product
1717
from typing import List
1818

@@ -417,42 +417,64 @@ def _validate(
417417
# --------------- Storage validators --------------- #
418418

419419

420-
def _check_in_out_access(security_groups_ids, port, is_cidr_optional, protocol="tcp"):
420+
def _is_access_allowed(security_groups_ids, subnets, port, all_security_groups, protocol="tcp"):
421421
"""
422422
Verify given list of security groups to check if they allow in and out access on the given port.
423423
424424
:param security_groups_ids: list of security groups to verify
425425
:param port: port to verify
426-
:param is_cidr_optional: if it is True, don't enforce check on CIDR.
426+
:param all_security_groups: all security groups from cluster. This is a set of frozen sets.
427+
Each frozen set contains sg combination of a queue.
427428
:param protocol: the IP protocol to be checked.
428429
:return: True if both in and out access are allowed
429430
:raise: ClientError if a given security group doesn't exist
430431
"""
431432
in_access = False
432433
out_access = False
433434

435+
in_ip_ranges = []
436+
out_ip_ranges = []
434437
for sec_group in AWSApi.instance().ec2.describe_security_groups(security_groups_ids):
435438
# Check all inbound rules
436439
for rule in sec_group.get("IpPermissions"):
437-
if _check_sg_rules_for_port(rule, port, protocol):
438-
if is_cidr_optional or rule.get("IpRanges") or rule.get("PrefixListIds"):
439-
in_access = True
440-
break
440+
if in_access:
441+
break
442+
if _is_port_allowed_by_sg_rule(rule, port, protocol):
443+
in_access = _is_src_or_dst_allowed_by_sg_rule(rule, all_security_groups, in_ip_ranges)
441444

442445
# Check all outbound rules
443446
for rule in sec_group.get("IpPermissionsEgress"):
444-
if _check_sg_rules_for_port(rule, port, protocol):
445-
if is_cidr_optional or rule.get("IpRanges") or rule.get("PrefixListIds"):
446-
out_access = True
447-
break
447+
if out_access:
448+
break
449+
if _is_port_allowed_by_sg_rule(rule, port, protocol):
450+
out_access = _is_src_or_dst_allowed_by_sg_rule(rule, all_security_groups, out_ip_ranges)
448451

449452
if in_access and out_access:
450453
return True
451454

455+
# Rules of ip ranges have to be checked at the end because the union of all ip ranges may cover the subnets,
456+
# even when individual ip ranges do not cover the subnets
457+
in_access = in_access or _are_subnets_covered_by_cidrs(in_ip_ranges, subnets)
458+
out_access = out_access or _are_subnets_covered_by_cidrs(out_ip_ranges, subnets)
459+
return in_access and out_access
460+
461+
462+
def _is_src_or_dst_allowed_by_sg_rule(rule, all_security_groups, ip_ranges):
463+
if rule.get("IpRanges"):
464+
ip_ranges.extend(rule.get("IpRanges"))
465+
return False # Ip Ranges have to be checked later. Return False because the rule allowance is not determined.
466+
elif rule.get("PrefixListIds"):
467+
return True # Always assume prefix list is properly set for code simplicity
468+
elif rule.get("UserIdGroupPairs"):
469+
allowed_security_groups = {
470+
user_id_group_pair.get("GroupId") for user_id_group_pair in rule.get("UserIdGroupPairs")
471+
}
472+
# For all cluster nodes, at least one of the security groups attached need to be in the UserIdGroupPairs.
473+
return all(node_security_groups & allowed_security_groups for node_security_groups in all_security_groups)
452474
return False
453475

454476

455-
def _check_sg_rules_for_port(rule, port_to_check, protocol):
477+
def _is_port_allowed_by_sg_rule(rule, port_to_check, protocol):
456478
"""
457479
Verify if the security group rule accepts connections on the given port.
458480
@@ -483,6 +505,23 @@ def _check_sg_rules_for_port(rule, port_to_check, protocol):
483505
return False
484506

485507

508+
def _are_subnets_covered_by_cidrs(ip_ranges, subnets):
509+
"""Verify given list of security groups to check if they allow in and out access on cluster subnet CIDRs."""
510+
# Collapse ip ranges for better performance and correctness
511+
collapsed_ip_ranges = list(collapse_addresses([ip_network(ip_range["CidrIp"]) for ip_range in ip_ranges]))
512+
513+
for subnet in subnets:
514+
subnet_cidr = ip_network(AWSApi.instance().ec2.get_subnet_cidr(subnet))
515+
covered = False
516+
for ip_range in collapsed_ip_ranges:
517+
if ip_range.supernet_of(subnet_cidr):
518+
covered = True
519+
break
520+
if not covered:
521+
return False
522+
return True
523+
524+
486525
class ExistingFsxNetworkingValidator(Validator):
487526
"""
488527
FSx networking validator.
@@ -504,27 +543,23 @@ def _describe_network_interfaces(self, file_systems):
504543
else:
505544
return {}
506545

507-
def _validate(self, file_system_ids, head_node_subnet_id, are_all_security_groups_customized):
546+
def _validate(self, file_system_ids, subnet_ids, all_security_groups):
508547
try:
509-
# Check to see if there is any existing mt on the fs
510548
file_systems = AWSApi.instance().fsx.get_file_systems_info(file_system_ids)
511-
512-
vpc_id = AWSApi.instance().ec2.get_subnet_vpc(head_node_subnet_id)
513-
514-
network_interfaces_data = self._describe_network_interfaces(file_systems)
515-
516-
self._check_file_systems(are_all_security_groups_customized, file_systems, network_interfaces_data, vpc_id)
549+
self._check_file_systems(all_security_groups, file_systems, subnet_ids)
517550
except AWSClientError as e:
518551
self._add_failure(str(e), FailureLevel.ERROR)
519552

520-
def _check_file_systems(self, are_all_security_groups_customized, file_systems, network_interfaces_data, vpc_id):
553+
def _check_file_systems(self, all_security_groups, file_systems, subnet_ids):
554+
vpc_id = AWSApi.instance().ec2.get_subnet_vpc(subnet_ids[0])
555+
network_interfaces_data = self._describe_network_interfaces(file_systems)
521556
for file_system in file_systems:
522557
# Check to see if fs is in the same VPC as the stack
523558
file_system_id = file_system.file_system_id
524559
if file_system.vpc_id != vpc_id:
525560
self._add_failure(
526561
"Currently only support using FSx file system that is in the same VPC as the cluster. "
527-
"The file system provided is in {0}.".format(file_system.vpc_id),
562+
f"The file system {file_system_id} is in {file_system.vpc_id}.",
528563
FailureLevel.ERROR,
529564
)
530565

@@ -545,7 +580,7 @@ def _check_file_systems(self, are_all_security_groups_customized, file_systems,
545580

546581
for protocol, ports in FSX_PORTS[file_system.file_system_type].items():
547582
missing_ports = self._get_missing_ports(
548-
are_all_security_groups_customized, network_interfaces, ports, protocol
583+
all_security_groups, subnet_ids, network_interfaces, ports, protocol
549584
)
550585

551586
if missing_ports:
@@ -557,17 +592,18 @@ def _check_file_systems(self, are_all_security_groups_customized, file_systems,
557592
FailureLevel.ERROR,
558593
)
559594

560-
def _get_missing_ports(self, are_all_security_groups_customized, network_interfaces, ports, protocol):
595+
def _get_missing_ports(self, all_security_groups, subnet_ids, network_interfaces, ports, protocol):
561596
missing_ports = []
562597
for port in ports:
563598
fs_access = False
564599
for network_interface in network_interfaces:
565600
# Get list of security group IDs
566601
sg_ids = [sg.get("GroupId") for sg in network_interface.get("Groups")]
567-
if _check_in_out_access(
602+
if _is_access_allowed(
568603
sg_ids,
604+
subnet_ids,
569605
port=port,
570-
is_cidr_optional=are_all_security_groups_customized,
606+
all_security_groups=all_security_groups,
571607
protocol=protocol,
572608
):
573609
fs_access = True
@@ -743,7 +779,7 @@ class EfsIdValidator(Validator): # TODO add tests
743779
Validate if there are existing mount target in the cluster (head and computes) availability zone
744780
"""
745781

746-
def _validate(self, efs_id, avail_zones_mapping: dict, are_all_security_groups_customized):
782+
def _validate(self, efs_id, avail_zones_mapping: dict, all_security_groups):
747783
availability_zones = avail_zones_mapping.keys()
748784
if len(availability_zones) > 1 and not AWSApi.instance().efs.is_efs_standard(efs_id):
749785
self._add_failure(
@@ -760,7 +796,7 @@ def _validate(self, efs_id, avail_zones_mapping: dict, are_all_security_groups_c
760796
if head_node_target_id:
761797
# Get list of security group IDs of the mount target
762798
sg_ids = AWSApi.instance().efs.get_efs_mount_target_security_groups(head_node_target_id)
763-
if not _check_in_out_access(sg_ids, port=2049, is_cidr_optional=are_all_security_groups_customized):
799+
if not _is_access_allowed(sg_ids, subnets, port=2049, all_security_groups=all_security_groups):
764800
self._add_failure(
765801
"There is an existing Mount Target {0} in the Availability Zone {1} for EFS {2}, "
766802
"but it does not have a security group that allows inbound and outbound rules to support NFS. "
@@ -769,8 +805,6 @@ def _validate(self, efs_id, avail_zones_mapping: dict, are_all_security_groups_c
769805
),
770806
FailureLevel.ERROR,
771807
)
772-
if not are_all_security_groups_customized:
773-
self._check_cidrs_cover_subnets(head_node_target_id, avail_zone, sg_ids, efs_id, subnets)
774808
else:
775809
if AWSApi.instance().efs.is_efs_standard(efs_id):
776810
avail_zones_missing_mount_target_for_efs_standard.append(avail_zone)
@@ -783,42 +817,6 @@ def _validate(self, efs_id, avail_zones_mapping: dict, are_all_security_groups_c
783817
FailureLevel.ERROR,
784818
)
785819

786-
def _check_subnet_access(self, security_groups_ids, subnet_cidr, access_type):
787-
permission = "IpPermissions" if access_type == "in" else "IpPermissionsEgress"
788-
access = False
789-
for sec_group in security_groups_ids:
790-
for rule in sec_group.get(permission):
791-
if rule.get("PrefixListIds"):
792-
access = True
793-
break
794-
if rule.get("IpRanges"):
795-
for ip_range in rule.get("IpRanges"):
796-
if ip_network(ip_range.get("CidrIp")).supernet_of(subnet_cidr):
797-
access = True
798-
break
799-
return access
800-
801-
def _check_cidrs_cover_subnets(self, head_node_target_id, avail_zone, security_groups_ids, efs_id, subnets):
802-
"""Verify given list of security groups to check if they allow in and out access on cluster subnet CIDRs."""
803-
security_groups_ids = AWSApi.instance().ec2.describe_security_groups(security_groups_ids)
804-
for subnet in subnets:
805-
subnet_cidr = ip_network(AWSApi.instance().ec2.get_subnet_cidr(subnet))
806-
in_access, out_access = self._check_subnet_access(
807-
security_groups_ids, subnet_cidr, "in"
808-
), self._check_subnet_access(security_groups_ids, subnet_cidr, "out")
809-
810-
if not in_access or not out_access:
811-
self._add_failure(
812-
"There is an existing Mount Target {0} in the Availability Zone {1} for EFS {2}, "
813-
"but it does not have a security group that allows inbound and outbound rules to allow traffic of "
814-
"subnet {3}. Please modify the Mount Target's security group, to allow traffic on subnet.".format(
815-
head_node_target_id, avail_zone, efs_id, subnet
816-
),
817-
FailureLevel.WARNING,
818-
)
819-
820-
return False
821-
822820

823821
class SharedStorageNameValidator(Validator):
824822
"""

cli/src/pcluster/validators/common.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,5 +69,6 @@ def _validate(self, *args, **kwargs):
6969
class ValidatorContext:
7070
"""Context containing information about cluster environment meant to be passed to validators."""
7171

72-
def __init__(self, head_node_instance_id: str = None):
72+
def __init__(self, head_node_instance_id: str = None, during_update: bool = None):
7373
self.head_node_instance_id = head_node_instance_id
74+
self.during_update = during_update

0 commit comments

Comments
 (0)