1212import re
1313from collections import defaultdict
1414from enum import Enum
15- from ipaddress import ip_network
15+ from ipaddress import collapse_addresses , ip_network
1616from itertools import combinations , product
1717from typing import List
1818
@@ -417,42 +417,64 @@ def _validate(
417417# --------------- Storage validators --------------- #
418418
419419
def _is_access_allowed(security_groups_ids, subnets, port, all_security_groups, protocol="tcp"):
    """
    Verify given list of security groups to check if they allow in and out access on the given port.

    A direction (inbound or outbound) is considered allowed when either:
      * one rule matching the port and protocol is accepted by _is_src_or_dst_allowed_by_sg_rule, or
      * the IP ranges collected from all the matching rules, taken together, cover every subnet in ``subnets``.

    :param security_groups_ids: list of security groups to verify
    :param subnets: subnet ids whose CIDRs must be covered when access is only granted through IP ranges
    :param port: port to verify
    :param all_security_groups: all security groups from cluster. This is a set of frozen sets.
                                Each frozen set contains sg combination of a queue.
    :param protocol: the IP protocol to be checked.
    :return: True if both in and out access are allowed
    :raise: ClientError if a given security group doesn't exist
    """
    in_access = False
    out_access = False

    # IP ranges of the rules matching port/protocol, accumulated across all the security groups
    in_ip_ranges = []
    out_ip_ranges = []
    for sec_group in AWSApi.instance().ec2.describe_security_groups(security_groups_ids):
        # Check all inbound rules ("or ()" tolerates a description that lacks the permission list)
        for rule in sec_group.get("IpPermissions") or ():
            if in_access:
                break
            if _is_port_allowed_by_sg_rule(rule, port, protocol):
                in_access = _is_src_or_dst_allowed_by_sg_rule(rule, all_security_groups, in_ip_ranges)

        # Check all outbound rules
        for rule in sec_group.get("IpPermissionsEgress") or ():
            if out_access:
                break
            if _is_port_allowed_by_sg_rule(rule, port, protocol):
                out_access = _is_src_or_dst_allowed_by_sg_rule(rule, all_security_groups, out_ip_ranges)

        # Both directions already granted: no need to inspect the remaining security groups
        if in_access and out_access:
            return True

    # Rules of ip ranges have to be checked at the end because the union of all ip ranges may cover the subnets,
    # even when individual ip ranges do not cover the subnets
    in_access = in_access or _are_subnets_covered_by_cidrs(in_ip_ranges, subnets)
    out_access = out_access or _are_subnets_covered_by_cidrs(out_ip_ranges, subnets)
    return in_access and out_access
460+
461+
462+ def _is_src_or_dst_allowed_by_sg_rule (rule , all_security_groups , ip_ranges ):
463+ if rule .get ("IpRanges" ):
464+ ip_ranges .extend (rule .get ("IpRanges" ))
465+ return False # Ip Ranges have to be checked later. Return False because the rule allowance is not determined.
466+ elif rule .get ("PrefixListIds" ):
467+ return True # Always assume prefix list is properly set for code simplicity
468+ elif rule .get ("UserIdGroupPairs" ):
469+ allowed_security_groups = {
470+ user_id_group_pair .get ("GroupId" ) for user_id_group_pair in rule .get ("UserIdGroupPairs" )
471+ }
472+ # For all cluster nodes, at least one of the security groups attached need to be in the UserIdGroupPairs.
473+ return all (node_security_groups & allowed_security_groups for node_security_groups in all_security_groups )
452474 return False
453475
454476
455- def _check_sg_rules_for_port (rule , port_to_check , protocol ):
477+ def _is_port_allowed_by_sg_rule (rule , port_to_check , protocol ):
456478 """
457479 Verify if the security group rule accepts connections on the given port.
458480
@@ -483,6 +505,23 @@ def _check_sg_rules_for_port(rule, port_to_check, protocol):
483505 return False
484506
485507
508+ def _are_subnets_covered_by_cidrs (ip_ranges , subnets ):
509+ """Verify given list of security groups to check if they allow in and out access on cluster subnet CIDRs."""
510+ # Collapse ip ranges for better performance and correctness
511+ collapsed_ip_ranges = list (collapse_addresses ([ip_network (ip_range ["CidrIp" ]) for ip_range in ip_ranges ]))
512+
513+ for subnet in subnets :
514+ subnet_cidr = ip_network (AWSApi .instance ().ec2 .get_subnet_cidr (subnet ))
515+ covered = False
516+ for ip_range in collapsed_ip_ranges :
517+ if ip_range .supernet_of (subnet_cidr ):
518+ covered = True
519+ break
520+ if not covered :
521+ return False
522+ return True
523+
524+
486525class ExistingFsxNetworkingValidator (Validator ):
487526 """
488527 FSx networking validator.
@@ -504,27 +543,23 @@ def _describe_network_interfaces(self, file_systems):
504543 else :
505544 return {}
506545
507- def _validate (self , file_system_ids , head_node_subnet_id , are_all_security_groups_customized ):
546+ def _validate (self , file_system_ids , subnet_ids , all_security_groups ):
508547 try :
509- # Check to see if there is any existing mt on the fs
510548 file_systems = AWSApi .instance ().fsx .get_file_systems_info (file_system_ids )
511-
512- vpc_id = AWSApi .instance ().ec2 .get_subnet_vpc (head_node_subnet_id )
513-
514- network_interfaces_data = self ._describe_network_interfaces (file_systems )
515-
516- self ._check_file_systems (are_all_security_groups_customized , file_systems , network_interfaces_data , vpc_id )
549+ self ._check_file_systems (all_security_groups , file_systems , subnet_ids )
517550 except AWSClientError as e :
518551 self ._add_failure (str (e ), FailureLevel .ERROR )
519552
520- def _check_file_systems (self , are_all_security_groups_customized , file_systems , network_interfaces_data , vpc_id ):
553+ def _check_file_systems (self , all_security_groups , file_systems , subnet_ids ):
554+ vpc_id = AWSApi .instance ().ec2 .get_subnet_vpc (subnet_ids [0 ])
555+ network_interfaces_data = self ._describe_network_interfaces (file_systems )
521556 for file_system in file_systems :
522557 # Check to see if fs is in the same VPC as the stack
523558 file_system_id = file_system .file_system_id
524559 if file_system .vpc_id != vpc_id :
525560 self ._add_failure (
526561 "Currently only support using FSx file system that is in the same VPC as the cluster. "
527- "The file system provided is in {0}." . format ( file_system .vpc_id ) ,
562+ f "The file system { file_system_id } is in { file_system .vpc_id } ." ,
528563 FailureLevel .ERROR ,
529564 )
530565
@@ -545,7 +580,7 @@ def _check_file_systems(self, are_all_security_groups_customized, file_systems,
545580
546581 for protocol , ports in FSX_PORTS [file_system .file_system_type ].items ():
547582 missing_ports = self ._get_missing_ports (
548- are_all_security_groups_customized , network_interfaces , ports , protocol
583+ all_security_groups , subnet_ids , network_interfaces , ports , protocol
549584 )
550585
551586 if missing_ports :
@@ -557,17 +592,18 @@ def _check_file_systems(self, are_all_security_groups_customized, file_systems,
557592 FailureLevel .ERROR ,
558593 )
559594
560- def _get_missing_ports (self , are_all_security_groups_customized , network_interfaces , ports , protocol ):
595+ def _get_missing_ports (self , all_security_groups , subnet_ids , network_interfaces , ports , protocol ):
561596 missing_ports = []
562597 for port in ports :
563598 fs_access = False
564599 for network_interface in network_interfaces :
565600 # Get list of security group IDs
566601 sg_ids = [sg .get ("GroupId" ) for sg in network_interface .get ("Groups" )]
567- if _check_in_out_access (
602+ if _is_access_allowed (
568603 sg_ids ,
604+ subnet_ids ,
569605 port = port ,
570- is_cidr_optional = are_all_security_groups_customized ,
606+ all_security_groups = all_security_groups ,
571607 protocol = protocol ,
572608 ):
573609 fs_access = True
@@ -743,7 +779,7 @@ class EfsIdValidator(Validator): # TODO add tests
743779 Validate if there are existing mount target in the cluster (head and computes) availability zone
744780 """
745781
746- def _validate (self , efs_id , avail_zones_mapping : dict , are_all_security_groups_customized ):
782+ def _validate (self , efs_id , avail_zones_mapping : dict , all_security_groups ):
747783 availability_zones = avail_zones_mapping .keys ()
748784 if len (availability_zones ) > 1 and not AWSApi .instance ().efs .is_efs_standard (efs_id ):
749785 self ._add_failure (
@@ -760,7 +796,7 @@ def _validate(self, efs_id, avail_zones_mapping: dict, are_all_security_groups_c
760796 if head_node_target_id :
761797 # Get list of security group IDs of the mount target
762798 sg_ids = AWSApi .instance ().efs .get_efs_mount_target_security_groups (head_node_target_id )
763- if not _check_in_out_access (sg_ids , port = 2049 , is_cidr_optional = are_all_security_groups_customized ):
799+ if not _is_access_allowed (sg_ids , subnets , port = 2049 , all_security_groups = all_security_groups ):
764800 self ._add_failure (
765801 "There is an existing Mount Target {0} in the Availability Zone {1} for EFS {2}, "
766802 "but it does not have a security group that allows inbound and outbound rules to support NFS. "
@@ -769,8 +805,6 @@ def _validate(self, efs_id, avail_zones_mapping: dict, are_all_security_groups_c
769805 ),
770806 FailureLevel .ERROR ,
771807 )
772- if not are_all_security_groups_customized :
773- self ._check_cidrs_cover_subnets (head_node_target_id , avail_zone , sg_ids , efs_id , subnets )
774808 else :
775809 if AWSApi .instance ().efs .is_efs_standard (efs_id ):
776810 avail_zones_missing_mount_target_for_efs_standard .append (avail_zone )
@@ -783,42 +817,6 @@ def _validate(self, efs_id, avail_zones_mapping: dict, are_all_security_groups_c
783817 FailureLevel .ERROR ,
784818 )
785819
786- def _check_subnet_access (self , security_groups_ids , subnet_cidr , access_type ):
787- permission = "IpPermissions" if access_type == "in" else "IpPermissionsEgress"
788- access = False
789- for sec_group in security_groups_ids :
790- for rule in sec_group .get (permission ):
791- if rule .get ("PrefixListIds" ):
792- access = True
793- break
794- if rule .get ("IpRanges" ):
795- for ip_range in rule .get ("IpRanges" ):
796- if ip_network (ip_range .get ("CidrIp" )).supernet_of (subnet_cidr ):
797- access = True
798- break
799- return access
800-
801- def _check_cidrs_cover_subnets (self , head_node_target_id , avail_zone , security_groups_ids , efs_id , subnets ):
802- """Verify given list of security groups to check if they allow in and out access on cluster subnet CIDRs."""
803- security_groups_ids = AWSApi .instance ().ec2 .describe_security_groups (security_groups_ids )
804- for subnet in subnets :
805- subnet_cidr = ip_network (AWSApi .instance ().ec2 .get_subnet_cidr (subnet ))
806- in_access , out_access = self ._check_subnet_access (
807- security_groups_ids , subnet_cidr , "in"
808- ), self ._check_subnet_access (security_groups_ids , subnet_cidr , "out" )
809-
810- if not in_access or not out_access :
811- self ._add_failure (
812- "There is an existing Mount Target {0} in the Availability Zone {1} for EFS {2}, "
813- "but it does not have a security group that allows inbound and outbound rules to allow traffic of "
814- "subnet {3}. Please modify the Mount Target's security group, to allow traffic on subnet." .format (
815- head_node_target_id , avail_zone , efs_id , subnet
816- ),
817- FailureLevel .WARNING ,
818- )
819-
820- return False
821-
822820
823821class SharedStorageNameValidator (Validator ):
824822 """
0 commit comments