diff --git a/pyproject.toml b/pyproject.toml index 97c3e09e7..c221f0937 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,6 +46,7 @@ paste = {optional = true, version = "*"} pyopenssl = "*" python-dateutil = "*" pytz = "*" +pydantic = {version = ">=1.7.4"} "repoze.who" = {optional = true, version = "*"} requests = "^2" xmlschema = ">=1.2.1" diff --git a/src/saml2/assertion.py b/src/saml2/assertion.py index 4c0ab1511..b1582174c 100644 --- a/src/saml2/assertion.py +++ b/src/saml2/assertion.py @@ -1,15 +1,35 @@ #!/usr/bin/env python +from __future__ import annotations + import copy import importlib import logging import re from warnings import warn as _warn +from typing import Any +from typing import Dict +from typing import List +from typing import Literal +from typing import Mapping +from typing import Optional +from typing import Type +from typing import TypedDict +from typing import TypeVar +from typing import Union +from warnings import warn as _warn + +from pydantic import BaseModel +from pydantic import ValidationError +from pydantic import validator + from saml2 import saml from saml2 import xmlenc +from saml2.attribute_converter import AttributeConverter from saml2.attribute_converter import ac_factory from saml2.attribute_converter import from_local from saml2.attribute_converter import get_local_name +from saml2.mdstore import MetadataStore from saml2.s_utils import MissingValue from saml2.s_utils import assertion_factory from saml2.s_utils import factory @@ -17,45 +37,80 @@ from saml2.saml import NAME_FORMAT_URI from saml2.time_util import in_a_while from saml2.time_util import instant +from saml2.typing import AttributeAsDict +from saml2.typing import AttributeValues +from saml2.typing import AttributeValuesStrict logger = logging.getLogger(__name__) +extra_logger = logger.getChild("extra") -def _filter_values(vals, vlist=None, must=False): - """Removes values from *vals* that does not appear in vlist +class EntityCategoryMatcher(BaseModel): + """ + Part of EntityCategoryRule. - :param vals: The values that are to be filtered - :param vlist: required or optional value - :param must: Whether the allowed values must appear - :return: The set of values after filtering + Decides, based on a list of entity categories for an SP, if this rule applies to the SP or not. """ - if not vlist: # No value specified equals any value - return vals + required: List[str] # List of entity category URIs that must be present in the SP's entity categories + conflicts: List[str] = [] # List of entity category URIs that must not be present in the SP's entity categories - if vals is None: # cannot iterate over None, return early - return vals + def matches(self, sp_ecs: List[str]) -> bool: + """Return True if all our entity categories is present in the list of SP entity categories""" + _conflicts = self._conflicts(sp_ecs) + if _conflicts: + extra_logger.debug(f"Not matching, SP entity categories in conflict with {self.conflicts}") + return False + if self.required == [""]: + # A rule with this matching criteria results in attributes always being released + return True + return all([x in sp_ecs for x in self.required]) - if isinstance(vlist, str): - vlist = [vlist] + def _conflicts(self, sp_ecs: List[str]) -> bool: + """Return True if any of the SP's entity categories are present in `conflicts'.""" + return any([x in sp_ecs for x in self.conflicts]) - res = [] - for val in vlist: - if val in vals: - res.append(val) +class EntityCategoryRule(BaseModel): + """A rule to decide whether or not to add a list of attributes for release to an SP.""" - if must: - if res: - return res - else: - raise MissingValue("Required attribute value missing") - else: - return res + match: EntityCategoryMatcher + attributes: List[str] # attributes to release if this rule matches (friendly names) + only_required: bool = False # If this rule matches, only include the required attributes for the SP + + @validator("attributes") + def lowercase_attribute_names(cls, v: List[str]): + """Make sure all attribute names are lower case, for easier comparison later.""" + return [x.lower() for x in v] + + +# The regexps are an optional "allow-list" for values. If regexps are provided, one of them has to +# match a value for it to be released. +AllowedAttributeValue = re.Pattern[str] +AttributeRestrictions = dict[str, Optional[list[AllowedAttributeValue]]] + + +def _filter_values(values: list[str], allowed_values: list[str], must: bool = False) -> list[str]: + """Removes values from *values* that does not appear in allowed_values. + + :param vals: The values that are to be filtered + :param allowed_values: required or optional values + :param must: Whether the allowed values must appear + :return: The set of values after filtering + """ + + if not allowed_values: # No value specified equals any value + return values + + res = [x for x in values if x in allowed_values] + + if must and not res: + raise MissingValue("Required attribute value missing") + return res -def _match(attr, ava): +def _match(attr: str, ava: AttributeValues) -> Optional[str]: if attr in ava: return attr @@ -70,20 +125,48 @@ def _match(attr, ava): return None -def filter_on_attributes(ava, required=None, optional=None, acs=None, fail_on_unfulfilled_requirements=True): - """Filter +AttributesAsDicts = list[AttributeAsDict] + + +def filter_on_attributes( + ava: AttributeValues, + required: Optional[AttributesAsDicts] = None, + optional: Optional[AttributesAsDicts] = None, + acs: Optional[list[AttributeConverter]] = None, + fail_on_unfulfilled_requirements: bool = True, +) -> AttributeValues: + """Filter attributes in `ava', returning a new instance of AttributeValues. + + * Ensure that all the values in the attribute value assertion are allowed + * Ensure that all the required attributes are present (else raise MissingValue) :param ava: An attribute value assertion as a dictionary - :param required: list of RequestedAttribute instances defined to be - required - :param optional: list of RequestedAttribute instances defined to be - optional + :param required: list of attributes defined to be required + :param optional: list of attributes defined to be optional :param fail_on_unfulfilled_requirements: If required attributes are missing fail or fail not depending on this parameter. :return: The modified attribute value assertion """ - def _match_attr_name(attr, ava): + def _filter_value_or_values( + val: Union[list[str], str], allowed_values: list[str], must: bool = False + ) -> Union[str, list[str]]: + """Convert single value to list of values before calling _filter_values.""" + values: list[str] + if isinstance(val, str): + values = [val] + else: + values = val + res = _filter_values(values, allowed_values, must) + return res + + def _identify_attribute(attr: AttributeAsDict, ava: AttributeValues) -> Optional[str]: + """Find and identify `attr' in `ava'. + + The attribute we want to work with might be identified by its name, name_format, + friendly_name or it's URI. This function tries to find the attribute in `ava' and + returns the friendly_name of the attribute in `ava'. + """ name = attr["name"].lower() name_format = attr.get("name_format") friendly_name = attr.get("friendly_name") @@ -95,38 +178,52 @@ def _match_attr_name(attr, ava): ) return _fn - def _apply_attr_value_restrictions(attr, res, must=False): - values = [av["text"] for av in attr.get("attribute_value", [])] + def _apply_attr_value_restrictions( + friendly_name: str, attr: AttributeAsDict, res: AttributeValuesStrict, must: bool = False + ): + """Add the attribute `friendly_name` to `res`, filtering its values if necessary.""" + _av_list = attr.get("attribute_value", []) + assert _av_list is not None # please mypy, the get() above defaults to empty list + allowed_values = [av["text"] for av in _av_list] - try: - res[_fn].extend(_filter_values(ava[_fn], values)) - except KeyError: - # ignore duplicate RequestedAttribute entries - val = _filter_values(ava[_fn], values) - res[_fn] = val if val is not None else [] + _values = _filter_value_or_values(ava[friendly_name], allowed_values, must) + if not _values: + return # nothing to add - return _filter_values(ava[_fn], values, must) + if friendly_name not in res: + res[friendly_name] = [] - res = {} + res[friendly_name].extend(_values) + + new_ava = AttributeValuesStrict({}) if required is None: required = [] for attr in required: - _fn = _match_attr_name(attr, ava) + _fn = _identify_attribute(attr, ava) if _fn: - _apply_attr_value_restrictions(attr, res, True) + _apply_attr_value_restrictions(_fn, attr, new_ava, True) elif fail_on_unfulfilled_requirements: - desc = f"Required attribute missing: '{attr['name']}'" - raise MissingValue(desc) + raise MissingValue(f"Required attribute missing: '{attr['name']}'") if optional is None: optional = [] for attr in optional: - _fn = _match_attr_name(attr, ava) + _fn = _identify_attribute(attr, ava) if _fn: - _apply_attr_value_restrictions(attr, res, False) + _apply_attr_value_restrictions(_fn, attr, new_ava, False) + + # TODO: Kludge to turn lists-of-strings back into strings if the data was given + # as a string in `ava`. This is needed to make the tests pass, but maybe it + # would be preferable to declare ava to only have lists of strings? + res = AttributeValues({}) + for this in new_ava.keys(): + if isinstance(ava[this], str): + res[this] = new_ava[this][0] + else: + res[this] = new_ava[this] return res @@ -213,7 +310,9 @@ def filter_on_wire_representation(ava, acs, required=None, optional=None): return res -def filter_attribute_value_assertions(ava, attribute_restrictions=None): +def filter_attribute_value_assertions( + ava: AttributeValues, attribute_restrictions: Optional[AttributeRestrictions] = None +) -> AttributeValues: """Will weed out attribute values and values according to the rules defined in the attribute restrictions. If filtering results in an attribute without values, then the attribute is removed from the @@ -225,10 +324,11 @@ def filter_attribute_value_assertions(ava, attribute_restrictions=None): :return: The modified attribute value assertion """ if not attribute_restrictions: + # If there are no restrictions, release everything we have return ava for attr, vals in list(ava.items()): - _attr = attr.lower() + _attr = attr.lower() # TODO: check if needed try: _rests = attribute_restrictions[_attr] except KeyError: @@ -238,7 +338,7 @@ def filter_attribute_value_assertions(ava, attribute_restrictions=None): continue if isinstance(vals, str): vals = [vals] - rvals = [] + rvals: list[str] = [] for restr in _rests: for val in vals: if restr.match(val): @@ -264,66 +364,215 @@ def restriction_from_attribute_spec(attributes): return restr -def compile(restrictions): - """This is only for IdPs or AAs, and it's about limiting what - is returned to the SP. - In the configuration file, restrictions on which values that - can be returned are specified with the help of regular expressions. - This function goes through and pre-compiles the regular expressions. +class EntityCategoryPolicy(BaseModel): + """Holder of rule sets for entity categories. - :param restrictions: policy configuration - :return: The assertion with the string specification replaced with - a compiled regular expression. + `categories' keys are category names (currently also module names from where the rules are loaded) + `categories' values are lists of rules for that category. """ - for who, spec in restrictions.items(): - spec = spec or {} - entity_categories = spec.get("entity_categories", []) - ecs = [] - for cat in entity_categories: + categories: dict[str, list[EntityCategoryRule]] + + def __str__(self) -> str: + return f"<{self.__class__.__name__}: {self.categories.keys()}>" + + @classmethod + def from_module_names(cls: Type["EntityCategoryPolicy"], entity_categories: List[str]) -> "EntityCategoryPolicy": + """Load a list of rules for a category. + + In the current implementation, the rules are loaded from a module - one module per category. + + The old format was to have the rules in the module's RELEASE dictionary, and the ONLY_REQUIRED dictionary. + + The new format is to load a list of rules from the RESTRICTIONS in the module, and use pydantic to validate + and convert the rules to EntityCategoryRule objects. + """ + res: dict[str, list[EntityCategoryRule]] = {} + for category in entity_categories: try: - _mod = importlib.import_module(cat) + _mod = importlib.import_module(category) except ImportError: - _mod = importlib.import_module(f"saml2.entity_category.{cat}") + _mod = importlib.import_module(f"saml2.entity_category.{category}") - _ec = {} + # `rules' is the list of rules loaded from this module + rules: list[EntityCategoryRule] = [] + + # Old format, load rules from RELEASE and ONLY_REQUIRED (two dictionaries) for key, items in _mod.RELEASE.items(): alist = [k.lower() for k in items] _only_required = getattr(_mod, "ONLY_REQUIRED", {}).get(key, False) - _no_aggregation = getattr(_mod, "NO_AGGREGATION", {}).get(key, False) - _ec[key] = (alist, _only_required, _no_aggregation) - ecs.append(_ec) - spec["entity_categories"] = ecs or None - attribute_restrictions = spec.get("attribute_restrictions") or {} - _attribute_restrictions = {} - for key, values in attribute_restrictions.items(): - lkey = key.lower() - values = [] if not values else values - _attribute_restrictions[lkey] = [re.compile(value) for value in values] or None - spec["attribute_restrictions"] = _attribute_restrictions or None + # Convert tuples to a list of strings, and a single string to a list of one string + _key_as_list: List[str] + if isinstance(key, str): + _key_as_list = [key] + else: + _key_as_list = list(key) + + rules.append( + EntityCategoryRule( + match=EntityCategoryMatcher(required=_key_as_list, conflicts=[]), + attributes=alist, + only_required=_only_required, + ) + ) + + # New format, load rules from RESTRICTIONS (a list) + if hasattr(_mod, "RESTRICTIONS") and isinstance(_mod.RESTRICTIONS, list): + for this in _mod.RESTRICTIONS: + try: + rules.append(EntityCategoryRule.parse_obj(this)) + except ValidationError: + logger.warning(f"Invalid entity category rule: {this}") + raise + + res[category] = rules + return cls(categories=res) + + def attribute_restrictions_for_sp( + self, + acs: List[AttributeConverter], + sp_entity_id: Optional[str] = None, + mds: Optional[MetadataStore] = None, # TODO: Possibly a 'MetaData' instance (parent of MetadataStore) + required: Optional[List[AttributeAsDict]] = None, + ) -> AttributeRestrictions: + """ + Compile the attribute restrictions for a given SP. + + Attribute restrictions are expressed as a dict with attribute names as keys + and optionally a list of regular expressions as values. + + If the value is a list of regular expressions, then the value of the attribute must match + one of the regular expressions. Otherwise the attribute is not allowed (meaning will not be released). + + If the value is None, then all values are allowed (think of it as "no restrictions apply"). + """ + restrictions: AttributeRestrictions = {} + + required_friendly_names: List[str] = [] + if required is not None: + for d in required: + # The dicts in 'required' can have a 'friendly_name', or a 'name' and a 'name_format'. + # See the documentation of the RequiredAttribute type. + _friendly_name: Optional[str] = d.get("friendly_name") + if not _friendly_name: + _friendly_name = get_local_name(acs=acs, attr=d["name"], name_format=d["name_format"]) + assert isinstance(_friendly_name, str) + required_friendly_names.append(_friendly_name.lower()) + + if not mds: + return restrictions + + sp_categories: List[str] = mds.entity_categories(sp_entity_id) + + extra_logger.debug( + f"Compiling attributes to release based on SP {sp_entity_id} entity categories: {sp_categories}" + ) + extra_logger.debug(f"Required attributes for this SP: {required_friendly_names}") + + for rule_set in self.categories.values(): + for this_rule in rule_set: + _matches = this_rule.match.matches(sp_categories) + extra_logger.debug(f"Rule {this_rule.match}, matches: {_matches}") + if _matches: + if this_rule.only_required: + attrs = [a for a in this_rule.attributes if a in required_friendly_names] + _not_adding = [a for a in this_rule.attributes if a not in required_friendly_names] + extra_logger.debug(f"Adding only required attributes: {attrs}, not adding: {_not_adding}") + else: + attrs = this_rule.attributes + extra_logger.debug(f"Adding attributes: {attrs}") + + for attr in attrs: + restrictions[attr] = None + + if not restrictions: + restrictions[""] = None + + logger.debug(f"Compiled attribute restrictions: {restrictions}") + return restrictions + + +PolicyConfigKey = Union[str, Literal["default"]] + - return restrictions +class PolicyConfigValue(BaseModel): + lifetime: Optional[Any] + attribute_restrictions: Optional[AttributeRestrictions] + name_form: Optional[str] + nameid_format: Optional[str] + entity_categories: EntityCategoryPolicy + sign: Optional[Union[Literal["response"], Literal["assertion"], Literal["on_demand"]]] + fail_on_missing_requested: Optional[bool] + + class Config: + arbitrary_types_allowed = True # allow re.Pattern as type in AttributeRestrictions + + +PolicyConfig = dict[PolicyConfigKey, PolicyConfigValue] class Policy: """Handles restrictions on assertions.""" - def __init__(self, restrictions=None, mds=None): + def __init__(self, restrictions: Optional[Mapping[str, Any]] = None, mds: Optional[MetadataStore] = None): self.metadata_store = mds self._restrictions = self.setup_restrictions(restrictions) logger.debug("policy restrictions: %s", self._restrictions) - self.acs = [] + self.acs: list[AttributeConverter] = [] - def setup_restrictions(self, restrictions=None): + def setup_restrictions(self, restrictions: Optional[Mapping[str, Any]] = None) -> Optional[PolicyConfig]: if restrictions is None: return None restrictions = copy.deepcopy(restrictions) - restrictions = compile(restrictions) + restrictions = self._compile_restrictions(restrictions) return restrictions - def get(self, attribute, sp_entity_id, default=None): + @staticmethod + def _compile_restrictions(restrictions: Mapping[str, Any]) -> PolicyConfig: + """ + Pre-compile regular expressions in rules in `restrictions'. + + This is only for IdPs or AAs, and it's about limiting what + is returned to the SP. + In the configuration file, restrictions on which values that + can be returned are specified with the help of regular expressions. + This function goes through and pre-compiles the regular expressions. + + :param restrictions: policy configuration + :return: The assertion with the string specification replaced with + a compiled regular expression. + """ + config: PolicyConfig = {} + for who, spec in restrictions.items(): + if spec is None: + spec = {} + + entity_categories: list[str] = spec.get("entity_categories", []) + _new_entity_categories = EntityCategoryPolicy.from_module_names(entity_categories) + + attribute_restrictions: Mapping[str, list[str]] = spec.get("attribute_restrictions") or {} + _attribute_restrictions: AttributeRestrictions = {} + for key, values in attribute_restrictions.items(): + lkey = key.lower() + values = [] if not values else values + _attribute_restrictions[lkey] = [re.compile(value) for value in values] or None + _new_attribute_restrictions = _attribute_restrictions or None + + config[who] = PolicyConfigValue( + lifetime=spec.get("lifetime"), + attribute_restrictions=_new_attribute_restrictions, + name_form=spec.get("name_form"), + nameid_format=spec.get("nameid_format"), + entity_categories=_new_entity_categories, + sign=spec.get("sign"), + fail_on_missing_requested=spec.get("fail_on_missing_requested"), + ) + + return config + + def get(self, attribute: str, sp_entity_id: str, default: Any = None) -> Any: """ :param attribute: @@ -334,42 +583,45 @@ def get(self, attribute, sp_entity_id, default=None): if not self._restrictions: return default - ra_info = self.metadata_store.registration_info(sp_entity_id) or {} if self.metadata_store is not None else {} - ra_entity_id = ra_info.get("registration_authority") + ra_info: Mapping[str, Any] = {} + if self.metadata_store is not None: + ra_info = self.metadata_store.registration_info(sp_entity_id) or {} + ra_entity_id: str = ra_info.get("registration_authority") # type: ignore[assignment] sp_restrictions = self._restrictions.get(sp_entity_id) ra_restrictions = self._restrictions.get(ra_entity_id) default_restrictions = self._restrictions.get("default") or self._restrictions.get("") - restrictions = ( + restrictions: Optional[PolicyConfigValue] = ( sp_restrictions if sp_restrictions is not None else ra_restrictions if ra_restrictions is not None else default_restrictions if default_restrictions is not None - else {} + else None ) - attribute_restriction = restrictions.get(attribute) - restriction = attribute_restriction if attribute_restriction is not None else default - return restriction + attribute_restriction = getattr(restrictions, attribute, None) + if attribute_restriction is None: + return default + return attribute_restriction - def get_nameid_format(self, sp_entity_id): + def get_nameid_format(self, sp_entity_id: str): """Get the NameIDFormat to used for the entity id :param: The SP entity ID - :retur: The format + :return: The format """ return self.get("nameid_format", sp_entity_id, saml.NAMEID_FORMAT_TRANSIENT) - def get_name_form(self, sp_entity_id): + def get_name_form(self, sp_entity_id: str): """Get the NameFormat to used for the entity id :param: The SP entity ID - :retur: The format + :return: The format """ return self.get("name_form", sp_entity_id, default=NAME_FORMAT_URI) - def get_lifetime(self, sp_entity_id): + def get_lifetime(self, sp_entity_id: str): """The lifetime of the assertion :param sp_entity_id: The SP entity ID :param: lifetime as a dictionary @@ -377,7 +629,7 @@ def get_lifetime(self, sp_entity_id): # default is a hour return self.get("lifetime", sp_entity_id, {"hours": 1}) - def get_attribute_restrictions(self, sp_entity_id): + def get_attribute_restrictions(self, sp_entity_id: str) -> Optional[AttributeRestrictions]: """Return the attribute restriction for SP that want the information :param sp_entity_id: The SP entity ID @@ -386,7 +638,7 @@ def get_attribute_restrictions(self, sp_entity_id): return self.get("attribute_restrictions", sp_entity_id) - def get_fail_on_missing_requested(self, sp_entity_id): + def get_fail_on_missing_requested(self, sp_entity_id: str): """Return the whether the IdP should should fail if the SPs requested attributes could not be found. @@ -396,7 +648,7 @@ def get_fail_on_missing_requested(self, sp_entity_id): return self.get("fail_on_missing_requested", sp_entity_id, default=True) - def get_sign(self, sp_entity_id): + def get_sign(self, sp_entity_id: str): """ Possible choices "sign": ["response", "assertion", "on_demand"] @@ -407,7 +659,9 @@ def get_sign(self, sp_entity_id): return self.get("sign", sp_entity_id, default=[]) - def get_entity_categories(self, sp_entity_id, mds=None, required=None): + def _get_restrictions_for_entity_categories( + self, sp_entity_id: str, mds: Optional[MetadataStore] = None, required: Optional[List[AttributeAsDict]] = None + ) -> AttributeRestrictions: """ :param sp_entity_id: @@ -424,61 +678,20 @@ def get_entity_categories(self, sp_entity_id, mds=None, required=None): logger.warning(warn_msg) _warn(warn_msg, DeprecationWarning) - def post_entity_categories(maps, sp_entity_id=None, mds=None, required=None): - restrictions = {} - required_friendly_names = [ - d.get("friendly_name") or get_local_name(acs=self.acs, attr=d["name"], name_format=d["name_format"]) - for d in (required or []) - ] - required = [friendly_name.lower() for friendly_name in required_friendly_names] - - if mds: - ecs = mds.entity_categories(sp_entity_id) - for ec_map in maps: - for key, (atlist, only_required, no_aggregation) in ec_map.items(): - if key == "": # always released - attrs = atlist - elif isinstance(key, tuple): - if only_required: - attrs = [a for a in atlist if a in required] - else: - attrs = atlist - for _key in key: - if _key not in ecs: - attrs = [] - break - elif key in ecs: - if only_required: - attrs = [a for a in atlist if a in required] - else: - attrs = atlist - else: - attrs = [] - - if attrs and no_aggregation: - # clear restrictions if the found category is a no aggregation category - restrictions = {} - for attr in attrs: - restrictions[attr] = None - else: - restrictions[""] = None - - return restrictions - - sentinel = object() - result1 = self.get("entity_categories", sp_entity_id, default=sentinel) - if result1 is sentinel: + result1: Optional[EntityCategoryPolicy] = self.get("entity_categories", sp_entity_id) + if result1 is None or not result1.categories: return {} - result2 = post_entity_categories( - result1, + assert isinstance(result1, EntityCategoryPolicy) + + return result1.attribute_restrictions_for_sp( + acs=self.acs, sp_entity_id=sp_entity_id, mds=(mds or self.metadata_store), required=required, ) - return result2 - def not_on_or_after(self, sp_entity_id): + def not_on_or_after(self, sp_entity_id: str): """When the assertion stops being valid, should not be used after this time. @@ -488,7 +701,14 @@ def not_on_or_after(self, sp_entity_id): return in_a_while(**self.get_lifetime(sp_entity_id)) - def filter(self, ava, sp_entity_id, mdstore=None, required=None, optional=None): + def filter( + self, + ava: AttributeValues, + sp_entity_id: str, + mdstore: Optional[MetadataStore] = None, + required: Optional[list[AttributeAsDict]] = None, + optional: Optional[list[AttributeAsDict]] = None, + ) -> AttributeValues: """What attribute and attribute values returns depends on what the SP or the registration authority has said it wants in the request or in the metadata file and what the IdP/AA wants to release. @@ -519,7 +739,7 @@ def filter(self, ava, sp_entity_id, mdstore=None, required=None, optional=None): subject_ava = ava.copy() # entity category restrictions - _ent_rest = self.get_entity_categories(sp_entity_id, mds=mdstore, required=required) + _ent_rest = self._get_restrictions_for_entity_categories(sp_entity_id, mds=mdstore, required=required) if _ent_rest: subject_ava = filter_attribute_value_assertions(subject_ava, _ent_rest) elif required or optional: @@ -538,7 +758,7 @@ def filter(self, ava, sp_entity_id, mdstore=None, required=None, optional=None): return subject_ava or {} - def restrict(self, ava, sp_entity_id, metadata=None): + def restrict(self, ava: AttributeValues, sp_entity_id: str, metadata: Optional[MetadataStore] = None): """Identity attribute names are expected to be expressed as FriendlyNames :return: A filtered ava according to the IdPs/AAs rules and @@ -853,3 +1073,8 @@ def apply_policy(self, sp_entity_id, policy): del self[key] return ava + + +def compile(restrictions: Mapping[str, Any]) -> PolicyConfig: + _warn("compile() is believe to be unused as an exported function and will be removed, use Policy() instead") + return Policy._compile_restrictions(restrictions) diff --git a/src/saml2/entity_category/swamid.py b/src/saml2/entity_category/swamid.py index 79bb4ed66..6d59ca8fc 100644 --- a/src/saml2/entity_category/swamid.py +++ b/src/saml2/entity_category/swamid.py @@ -100,8 +100,6 @@ HEI = "http://www.swamid.se/category/hei-service" # Deprecated from 2021-03-31 RELEASE = { - # NOTICE: order is important - # no-aggregation categories need to come last and in order of least to most restrictive "": [], SFS_1993_1153: ["norEduPersonNIN", "eduPersonAssurance"], (RESEARCH_AND_EDUCATION, EU): NAME + STATIC_ORG_INFO + OTHER, @@ -113,12 +111,6 @@ ESI: MYACADEMICID_ESI, (ESI, COCOv1): MYACADEMICID_ESI + GEANT_COCO, (ESI, COCOv2): MYACADEMICID_ESI + REFEDS_COCO, - # XXX: disabled temporarily until we can figure out how to handle them - # these need to be able to be combined with other categories just not with each other - # no aggregation categories - # PERSONALIZED: REFEDS_PERSONALIZED_ACCESS, - # PSEUDONYMOUS: REFEDS_PSEUDONYMOUS_ACCESS, - # ANONYMOUS: REFEDS_ANONYMOUS_ACCESS, } ONLY_REQUIRED = { @@ -128,8 +120,42 @@ (ESI, COCOv2): True, } -NO_AGGREGATION = { - PERSONALIZED: True, - PSEUDONYMOUS: True, - ANONYMOUS: True, -} +# These restrictions are parsed (and validated) into a list of saml2.assertion.EntityCategoryRule instances. +RESTRICTIONS = [ + { + "match": { + "required": [PERSONALIZED], + "conflicts": [PSEUDONYMOUS, ANONYMOUS], + }, + "attributes": REFEDS_PERSONALIZED_ACCESS, + }, + { + "match": { + "required": [PSEUDONYMOUS], + "conflicts": [ANONYMOUS], + }, + "attributes": REFEDS_PSEUDONYMOUS_ACCESS, + }, + { + "match": { + "required": [ANONYMOUS], + }, + "attributes": REFEDS_ANONYMOUS_ACCESS, + }, + # Example of conversion of some of the rules in RELEASE to this new format: + # + # { + # "match": { + # "required": [COCOv1], + # }, + # "attributes": GEANT_COCO, + # "only_required": True, + # }, + # { + # "match": { + # "required": [ESI, COCOv1], + # }, + # "attributes": MYACADEMICID_ESI + GEANT_COCO, + # "only_required": True, + # }, +] diff --git a/src/saml2/s_utils.py b/src/saml2/s_utils.py index 04a19c9ec..f27108d6e 100644 --- a/src/saml2/s_utils.py +++ b/src/saml2/s_utils.py @@ -8,6 +8,7 @@ import string import sys import traceback +from typing import Union import zlib from saml2 import VERSION @@ -327,7 +328,7 @@ def do_ava(val, typ=""): return attrval -def do_attribute(val, typ, key): +def do_attribute(val, typ, key: Union[str, tuple]) -> saml.Attribute: attr = saml.Attribute() attrval = do_ava(val, typ) if attrval: diff --git a/src/saml2/typing.py b/src/saml2/typing.py new file mode 100644 index 000000000..18922877f --- /dev/null +++ b/src/saml2/typing.py @@ -0,0 +1,30 @@ +# Type information for common pysaml2 data types, often found in configuration etc. +# + +from typing import Literal +from typing import Mapping +from typing import Optional +from typing import TypedDict +from typing import Union + + +# Required attributes are specified as dicts, e.g.: +# +# { +# "friendly_name": "eduPersonScopedAffiliation", +# "name": "1.3.6.1.4.1.5923.1.1.1.9", +# "name_format": NAME_FORMAT_URI, +# "is_required": "true", +# "attribute_value": [{"text": Any, ...}] +# } +class AttributeAsDict(TypedDict): + friendly_name: Optional[str] + name: str + name_format: str + is_required: Union[Literal["true"], Literal["false"]] + attribute_value: Optional[list[Mapping[str, str]]] + + +# Type for the common 'ava' parameter. +AttributeValues = dict[str, Union[list[str], str]] +AttributeValuesStrict = dict[str, list[str]] diff --git a/tests/entity_personalized_sp.xml b/tests/entity_personalized_sp.xml index aa48693a4..64bd98f00 100644 --- a/tests/entity_personalized_sp.xml +++ b/tests/entity_personalized_sp.xml @@ -5,6 +5,7 @@ https://refeds.org/category/personalized + https://refeds.org/category/code-of-conduct/v2 @@ -67,6 +68,7 @@ wHyaxzYldWmVC5omkgZeAdCGpJ316GQF8Zwg/yDOUzm4cvGeIESf1Q6ZxBwI6zGE personalized-SP refeds personalized access SP + diff --git a/tests/test_20_assertion.py b/tests/test_20_assertion.py index c40bde4bb..e713de37a 100644 --- a/tests/test_20_assertion.py +++ b/tests/test_20_assertion.py @@ -33,6 +33,7 @@ from saml2.saml import AttributeValue from saml2.saml import Issuer from saml2.saml import NameID +from saml2.typing import AttributeValues ONTS = [saml, mdui, mdattr, dri, idpdisc, md, xmldsig, xmlenc] @@ -181,6 +182,29 @@ def test_filter_on_attributes_with_missing_name_format(): assert ava["eduPersonTargetedID"] == "test@example.com" +def test_filter_on_attributes_with_not_allowed_value(): + + a = to_dict( + Attribute( + friendly_name="surName", + name="urn:oid:2.5.4.4", + name_format=NAME_FORMAT_URI, + attribute_value=[{"text": "A value"}], + ), + ONTS, + ) + attributes = [a] + ava = {"sn": ["Not the allowed value"]} + + ava = filter_on_attributes(ava, optional=attributes, acs=ac_factory()) + assert list(ava.keys()) == [] + + ava = {"sn": ["Not the allowed value"]} + + with raises(MissingValue): + filter_on_attributes(ava, required=attributes, acs=ac_factory()) + + # ---------------------------------------------------------------------- @@ -301,7 +325,7 @@ def test_ava_filter_dont_fail(): policy = Policy(conf) - ava = {"givenName": "Derek", "surName": "Jeter", "mail": "derek@example.com"} + ava = AttributeValues({"givenName": "Derek", "surName": "Jeter", "mail": "derek@example.com"}) # mail removed because it doesn't match the regular expression # So it should fail if the 'fail_on_ ...' flag wasn't set @@ -309,7 +333,7 @@ def test_ava_filter_dont_fail(): assert _ava - ava = {"givenName": "Derek", "surName": "Jeter"} + ava = AttributeValues({"givenName": "Derek", "surName": "Jeter"}) # it wasn't there to begin with _ava = policy.filter(ava, "urn:mace:umu.se:saml:roland:sp", required=[gn, sn, mail]) diff --git a/tests/test_31_config.py b/tests/test_31_config.py index c7741c466..6bd169e06 100644 --- a/tests/test_31_config.py +++ b/tests/test_31_config.py @@ -6,6 +6,7 @@ from saml2 import BINDING_HTTP_REDIRECT from saml2 import BINDING_SOAP +from saml2.assertion import EntityCategoryPolicy from saml2.authn_context import PASSWORDPROTECTEDTRANSPORT as AUTHN_PASSWORD_PROTECTED from saml2.authn_context import TIMESYNCTOKEN as AUTHN_TIME_SYNC_TOKEN from saml2.config import Config @@ -202,6 +203,23 @@ "crypto_backend": "XMLSecurity", } +IDP_SWAMID = { + "entityid": "urn:mace:umu.se:saml:sunet:idp", + "name": "Swamid entity categories restrictions in an IdP", + "service": { + "idp": { + "endpoints": { + "single_sign_on_service": ["http://localhost:8088/"], + }, + "policy": { + "default": { + "entity_categories": ["swamid", "edugain"], + }, + }, + } + }, +} + def _eq(l1, l2): return set(l1) == set(l2) @@ -388,5 +406,24 @@ def test_set_force_authn(): assert bool(cnf.getattr("force_authn", "sp")) == True +def test_idp_loading_entity_categories_restrictions(): + """Test loading an IdP config with two entity categories in the policy and make sure they are loaded correctly""" + c = IdPConfig().load(IDP_SWAMID) + c.context = "idp" + + ec_policy: EntityCategoryPolicy = c.getattr("policy", "idp").get("entity_categories", "sp entity id") + print(ec_policy) + + assert list(ec_policy.categories.keys()) == ["swamid", "edugain"] + + # Make sure all the categories have at least two rules (edugain currently has two) + for rules in ec_policy.categories.values(): + assert len(rules) >= 2 + + # Look for a specific rule in the swamid category + swamid_rules = ec_policy.categories["swamid"] + assert any(rule.match.required == ["https://refeds.org/category/personalized"] for rule in swamid_rules) + + if __name__ == "__main__": test_crypto_backend() diff --git a/tests/test_37_entity_categories.py b/tests/test_37_entity_categories.py index 894b03cf3..3372309ed 100644 --- a/tests/test_37_entity_categories.py +++ b/tests/test_37_entity_categories.py @@ -14,6 +14,7 @@ from saml2.mdstore import MetadataStore from saml2.saml import NAME_FORMAT_URI from saml2.server import Server +from saml2.typing import AttributeAsDict ATTRCONV = ac_factory(full_path("attributemaps")) @@ -293,7 +294,6 @@ def test_filter_ava_esi_coco(): ) -@pytest.mark.skip("Temporarily disabled") def test_filter_ava_refeds_anonymous_access(): entity_id = "https://anonymous.example.edu/saml2/metadata/" mds = MetadataStore(ATTRCONV, sec_config, disable_ssl_certificate_validation=True) @@ -322,7 +322,6 @@ def test_filter_ava_refeds_anonymous_access(): assert _eq(ava["schacHomeOrganization"], ["example.com"]) -@pytest.mark.skip("Temporarily disabled") def test_filter_ava_refeds_pseudonymous_access(): entity_id = "https://pseudonymous.example.edu/saml2/metadata/" mds = MetadataStore(ATTRCONV, sec_config, disable_ssl_certificate_validation=True) @@ -355,7 +354,6 @@ def test_filter_ava_refeds_pseudonymous_access(): assert _eq(ava["schacHomeOrganization"], ["example.com"]) -@pytest.mark.skip("Temporarily disabled") def test_filter_ava_refeds_personalized_access(): entity_id = "https://personalized.example.edu/saml2/metadata/" mds = MetadataStore(ATTRCONV, sec_config, disable_ssl_certificate_validation=True) @@ -377,7 +375,11 @@ def test_filter_ava_refeds_personalized_access(): "subject-id": ["subject-id@example.com"], } - ava = policy.filter(ava, entity_id) + attribute_requirements = mds.attribute_requirement(entity_id) + required = attribute_requirements.get("required", []) + optional = attribute_requirements.get("optional", []) + + ava = policy.filter(ava, entity_id, required=required, optional=optional) assert _eq( list(ava.keys()), @@ -390,6 +392,7 @@ def test_filter_ava_refeds_personalized_access(): "eduPersonScopedAffiliation", "eduPersonAssurance", "schacHomeOrganization", + "eduPersonTargetedID", ], ) assert _eq(ava["subject-id"], ["subject-id@example.com"])