diff --git a/synapse_token_authenticator/auth_headers.py b/synapse_token_authenticator/auth_headers.py new file mode 100644 index 0000000..3bb2355 --- /dev/null +++ b/synapse_token_authenticator/auth_headers.py @@ -0,0 +1,70 @@ +from __future__ import annotations + +from base64 import b64encode +from dataclasses import dataclass +from typing import Protocol + + +class HttpAuth(Protocol): + def header_map(self) -> dict[bytes, list[bytes]]: + """Retrieve the mapping for the authorization for Header generation""" + ... + + +@dataclass +class NoAuth: + def header_map(self) -> dict[bytes, list[bytes]]: + return {} + + +@dataclass +class BasicAuth: + username: str + password: str + + def header_map(self) -> dict[bytes, list[bytes]]: + return basic_auth(self.username, self.password) + + +@dataclass +class BearerAuth: + token: str + + def header_map(self) -> dict[bytes, list[bytes]]: + return bearer_auth(self.token) + + +def parse_auth(d: dict | list) -> HttpAuth: + if isinstance(d, dict): + _type = d.pop("type") + if _type is None: + return NoAuth() + elif _type == "basic": + return BasicAuth(**d) + elif _type == "bearer": + return BearerAuth(**d) + else: + raise Exception(f"Unknown HttpAuth type {_type}") + elif isinstance(d, list): + _type = d.pop(0) + if _type is None: + return NoAuth() + elif _type == "basic": + return BasicAuth(*d) + elif _type == "bearer": + return BearerAuth(*d) + else: + raise Exception(f"Unknown HttpAuth type {_type}") + else: + raise Exception("HttpAuth parsing failed, expected list or dict") + + +def basic_auth(username: str, password: str) -> dict[bytes, list[bytes]]: + authorization = b64encode( + b":".join((username.encode("utf8"), password.encode("utf8"))) + ) + return {b"Authorization": [b"Basic " + authorization]} + + +def bearer_auth(token: str) -> dict[bytes, list[bytes]]: + return {b"Authorization": [b"Bearer " + token.encode("utf8")]} diff --git a/synapse_token_authenticator/claims_validator.py b/synapse_token_authenticator/claims_validator.py index 5553e11..8970db2 100644 --- a/synapse_token_authenticator/claims_validator.py +++ b/synapse_token_authenticator/claims_validator.py @@ -35,7 +35,7 @@ ] -def parse_validator(d: dict) -> Validator: +def parse_validator(d: dict | list) -> Validator: if isinstance(d, dict): type = d.pop("type") if type == "exist": @@ -94,7 +94,7 @@ def validate(self, x: Any) -> bool: class Not: validator: Validator - def __post_init__(self): + def __post_init__(self) -> None: self.validator = parse_validator(self.validator) def validate(self, x: Any) -> bool: @@ -114,7 +114,7 @@ class MatchesRegex: regex: str full_match: bool | None = True - def __post_init__(self): + def __post_init__(self) -> None: self.regex_prog = re.compile(self.regex) def validate(self, s: Any) -> bool: @@ -130,7 +130,7 @@ def validate(self, s: Any) -> bool: class AnyOf: validators: List[Validator] - def __post_init__(self): + def __post_init__(self) -> None: self.validators = list(map(lambda v: parse_validator(v), self.validators)) def validate(self, x: Any) -> bool: @@ -141,7 +141,7 @@ def validate(self, x: Any) -> bool: class AllOf: validators: List[Validator] - def __post_init__(self): + def __post_init__(self) -> None: self.validators = list(map(lambda v: parse_validator(v), self.validators)) def validate(self, x: Any) -> bool: @@ -153,7 +153,7 @@ class In: path: str | List[str] validator: Optional[Validator] = None - def __post_init__(self): + def __post_init__(self) -> None: if not self.path: raise Exception("Path list is empty") if self.validator: @@ -172,7 +172,7 @@ def validate(self, x: Any) -> bool: class ListAllOf: validator: Validator - def __post_init__(self): + def __post_init__(self) -> None: if self.validator: self.validator = parse_validator(self.validator) @@ -186,7 +186,7 @@ def validate(self, list_: Any) -> bool: class ListAnyOf: validator: Validator - def __post_init__(self): + def __post_init__(self) -> None: if self.validator: self.validator = parse_validator(self.validator) diff --git a/synapse_token_authenticator/config.py b/synapse_token_authenticator/config.py index 1cadbc9..75ca00e 100644 --- a/synapse_token_authenticator/config.py +++ b/synapse_token_authenticator/config.py @@ -3,206 +3,210 @@ from typing import Any, List, Literal, TypeAlias, Union from jwcrypto.jwk import JWK, JWKSet +from synapse.types import JsonDict +from synapse_token_authenticator.auth_headers import HttpAuth, NoAuth, parse_auth from synapse_token_authenticator.claims_validator import ( Exist, Validator, parse_validator, ) -from synapse_token_authenticator.utils import basic_auth, bearer_auth +Path: TypeAlias = Union[str, List[str]] +PathList: TypeAlias = Union[Path, List[List[str]]] -class TokenAuthenticatorConfig: - """ - Parses and validates the provided config dictionary. - """ - def __init__(self, other: dict): - if jwt := other.get("jwt"): +class JwtConfig: + def __init__(self, other: dict) -> None: + self.secret: str | None = other.get("secret") + self.keyfile: str | None = other.get("keyfile") - class JwtConfig: - def __init__(self, other: dict): - self.secret: str | None = other.get("secret") - self.keyfile: str | None = other.get("keyfile") + self.algorithm: str = other.get("algorithm", "HS512") + self.allow_registration: bool = other.get("allow_registration", False) + self.require_expiry: bool = other.get("require_expiry", True) - self.algorithm: str = other.get("algorithm", "HS512") - self.allow_registration: bool = other.get( - "allow_registration", False - ) - self.require_expiry: bool = other.get("require_expiry", True) - self.jwt = JwtConfig(jwt) - verify_jwt_based_cfg(self.jwt) +class OIDCConfig: + def __init__(self, other: dict) -> None: + try: + self.issuer: str = other["issuer"] + self.client_id: str = other["client_id"] + self.client_secret: str = other["client_secret"] + self.project_id: str = other["project_id"] + self.organization_id: str = other["organization_id"] + except KeyError as error: + raise Exception(f"Config option must be set: {error.args[0]}") - if oidc := other.get("oidc"): + self.allowed_client_ids: str | None = other.get("allowed_client_ids") - class OIDCConfig: - def __init__(self, other: dict): - try: - self.issuer: str = other["issuer"] - self.client_id: str = other["client_id"] - self.client_secret: str = other["client_secret"] - self.project_id: str = other["project_id"] - self.organization_id: str = other["organization_id"] - except KeyError as error: - raise Exception(f"Config option must be set: {error.args[0]}") + self.allow_registration: bool = other.get("allow_registration", False) - self.allowed_client_ids: str | None = other.get( - "allowed_client_ids" - ) - self.allow_registration: bool = other.get( - "allow_registration", False - ) +@dataclass +class JwtValidationConfig: + validator: Validator = field(default_factory=Exist) + require_expiry: bool = False + localpart_path: Path | None = None + user_id_path: Path | None = None + fq_uid_path: Path | None = None + displayname_path: Path | None = None + admin_path: PathList | None = None + email_path: Path | None = None + required_scopes: str | List[str] | None = None + jwk_set: JWKSet | JWK | None = None + jwk_file: str | None = None + jwks_endpoint: str | None = None + + def __post_init__(self) -> None: + if not isinstance(self.validator, Exist): + self.validator = parse_validator(self.validator) + + if self.jwk_set and ("keys" in self.jwk_set): + self.jwk_set = JWKSet(**self.jwk_set) + elif self.jwk_set: + self.jwk_set = JWK(**self.jwk_set) + elif self.jwk_file: + with open(self.jwk_file) as f: + self.jwk_set = JWK.from_pem(f.read()) + elif not self.jwks_endpoint: + raise Exception("No JWK") - self.oidc = OIDCConfig(oidc) - if config := other.get("oauth"): - - Path: TypeAlias = Union[str, List[str]] - PathList: TypeAlias = Union[Path, List[List[str]]] - - @dataclass - class JwtValidationConfig: - validator: Validator = field(default_factory=Exist) - require_expiry: bool = False - localpart_path: Path | None = None - user_id_path: Path | None = None - fq_uid_path: Path | None = None - displayname_path: Path | None = None - admin_path: PathList | None = None - email_path: Path | None = None - required_scopes: str | List[str] | None = None - jwk_set: JWKSet | JWK | None = None - jwk_file: str | None = None - jwks_endpoint: str | None = None - - def __post_init__(self): - if not isinstance(self.validator, Exist): - self.validator = parse_validator(self.validator) - - if self.jwk_set and ("keys" in self.jwk_set): - self.jwk_set = JWKSet(**self.jwk_set) - elif self.jwk_set: - self.jwk_set = JWK(**self.jwk_set) - elif self.jwk_file: - with open(self.jwk_file) as f: - self.jwk_set = JWK.from_pem(f.read()) - elif not self.jwks_endpoint: - raise Exception("No JWK") - - @dataclass - class IntrospectionValidationConfig: - endpoint: str - validator: Validator = field(default_factory=Exist) - auth: HttpAuth = field(default_factory=NoAuth) - localpart_path: Path | None = None - user_id_path: Path | None = None - fq_uid_path: Path | None = None - displayname_path: Path | None = None - admin_path: PathList | None = None - email_path: Path | None = None - required_scopes: str | List[str] | None = None - - def __post_init__(self): - if not isinstance(self.validator, Exist): - self.validator = parse_validator(self.validator) - - if not isinstance(self.auth, NoAuth): - self.auth = parse_auth(self.auth) - - @dataclass - class NotifyOnRegistration: - url: str - auth: HttpAuth = field(default_factory=NoAuth) - interrupt_on_error: bool = True - - def __post_init__(self): - if not isinstance(self.auth, NoAuth): - self.auth = parse_auth(self.auth) - - @dataclass - class OAuthConfig: - jwt_validation: JwtValidationConfig | None = None - introspection_validation: IntrospectionValidationConfig | None = None - username_type: Literal["fq_uid", "localpart", "user_id"] | None = None - notify_on_registration: NotifyOnRegistration | None = None - expose_metadata_resource: Any = None - registration_enabled: bool = False - check_external_id: bool = True - - def __post_init__(self): - if self.notify_on_registration: - self.notify_on_registration = NotifyOnRegistration( - **self.notify_on_registration - ) - if self.jwt_validation: - self.jwt_validation = JwtValidationConfig( - **(self.jwt_validation) - ) - if self.introspection_validation: - self.introspection_validation = IntrospectionValidationConfig( - **self.introspection_validation - ) - if not (self.jwt_validation or self.introspection_validation): - raise Exception( - "Neither jwt_validation nor introspection_validation was specified" - ) - if self.username_type not in [ - "fq_uid", - "localpart", - "user_id", - None, - ]: - raise Exception(f"Unknown username_type {self.username_type}") +@dataclass +class IntrospectionValidationConfig: + endpoint: str + validator: Validator = field(default_factory=Exist) + auth: HttpAuth = field(default_factory=NoAuth) + localpart_path: Path | None = None + user_id_path: Path | None = None + fq_uid_path: Path | None = None + displayname_path: Path | None = None + admin_path: PathList | None = None + email_path: Path | None = None + required_scopes: str | List[str] | None = None + + def __post_init__(self) -> None: + if not isinstance(self.validator, Exist): + self.validator = parse_validator(self.validator) + + if not isinstance(self.auth, NoAuth): + self.auth = parse_auth(self.auth) - self.oauth = OAuthConfig(**config) - if epa := other.get("epa"): - - @dataclass - class EPaConfig: - iss: str - resource_id: str - validator: Validator = field(default_factory=Exist) - expose_metadata_resource: Any = None - registration_enabled: bool = False - enc_jwk: JWK | None = None - enc_jwk_file: str | None = None - enc_jwks_endpoint: str = "/.well-known/jwks.json" - jwk_set: JWKSet | JWK | None = None - jwk_file: str | None = None - jwks_endpoint: str | None = None - localpart_path: str | None = None - displayname_path: str | None = None - lowercase_localpart: bool = False - - def __post_init__(self): - if not isinstance(self.validator, Exist): - self.validator = parse_validator(self.validator) - - if self.enc_jwk: - self.enc_jwk = JWK(**self.enc_jwk) - elif self.enc_jwk_file: - with open(self.enc_jwk_file) as f: - self.enc_jwk = JWK.from_pem(f.read()) - else: - raise Exception("No encryption JWK") - - if self.jwk_set and ("keys" in self.jwk_set): - self.jwk_set = JWKSet(**self.jwk_set) - elif self.jwk_set: - self.jwk_set = JWK(**self.jwk_set) - elif self.jwk_file: - with open(self.jwk_file) as f: - self.jwk_set = JWK.from_pem(f.read()) - elif not self.jwks_endpoint: - raise Exception("No JWK") +@dataclass +class NotifyOnRegistration: + url: str + auth: HttpAuth = field(default_factory=NoAuth) + interrupt_on_error: bool = True + + def __post_init__(self) -> None: + if not isinstance(self.auth, NoAuth): + self.auth = parse_auth(self.auth) + + +@dataclass +class OAuthConfig: + jwt_validation: JwtValidationConfig | None = None + introspection_validation: IntrospectionValidationConfig | None = None + username_type: Literal["fq_uid", "localpart", "user_id"] | None = None + notify_on_registration: NotifyOnRegistration | None = None + expose_metadata_resource: Any = None + registration_enabled: bool = False + check_external_id: bool = True + + def __post_init__(self) -> None: + if self.notify_on_registration: + self.notify_on_registration = NotifyOnRegistration( + **self.notify_on_registration + ) + if self.jwt_validation: + self.jwt_validation = JwtValidationConfig(**(self.jwt_validation)) + if self.introspection_validation: + self.introspection_validation = IntrospectionValidationConfig( + **self.introspection_validation + ) + if not (self.jwt_validation or self.introspection_validation): + raise Exception( + "Neither jwt_validation nor introspection_validation was specified" + ) + if self.username_type not in [ + "fq_uid", + "localpart", + "user_id", + None, + ]: + raise Exception(f"Unknown username_type {self.username_type}") + + +@dataclass +class EPaConfig: + iss: str + resource_id: str + validator: Validator = field(default_factory=Exist) + expose_metadata_resource: Any = None + registration_enabled: bool = False + enc_jwk: JWK | None = None + enc_jwk_file: str | None = None + enc_jwks_endpoint: str = "/.well-known/jwks.json" + jwk_set: JWKSet | JWK | None = None + jwk_file: str | None = None + jwks_endpoint: str | None = None + localpart_path: str | None = None + displayname_path: str | None = None + lowercase_localpart: bool = False + + def __post_init__(self) -> None: + if not isinstance(self.validator, Exist): + self.validator = parse_validator(self.validator) + + if self.enc_jwk: + self.enc_jwk = JWK(**self.enc_jwk) + elif self.enc_jwk_file: + with open(self.enc_jwk_file) as f: + self.enc_jwk = JWK.from_pem(f.read()) + else: + raise Exception("No encryption JWK") + + if self.jwk_set and ("keys" in self.jwk_set): + self.jwk_set = JWKSet(**self.jwk_set) + elif self.jwk_set: + self.jwk_set = JWK(**self.jwk_set) + elif self.jwk_file: + with open(self.jwk_file) as f: + self.jwk_set = JWK.from_pem(f.read()) + elif not self.jwks_endpoint: + raise Exception("No JWK") + + +class TokenAuthenticatorConfig: + """ + Parses and validates the provided config dictionary. + """ + jwt: JwtConfig | None = None + oidc: OIDCConfig | None = None + oauth: OAuthConfig | None = None + epa: EPaConfig | None = None + + def __init__(self, other: JsonDict) -> None: + # Walrus operators judge the value as truthy/falsey, not as strictly + # None/not-None. An empty container(like a dict) is considered falsey. + if jwt := other.get("jwt", {}): + self.jwt = JwtConfig(jwt) + if self.jwt: + verify_jwt_based_cfg(self.jwt) + + if oidc := other.get("oidc", {}): + self.oidc = OIDCConfig(oidc) + + if config := other.get("oauth", {}): + self.oauth = OAuthConfig(**config) + + if epa := other.get("epa", {}): self.epa = EPaConfig(**epa) -def verify_jwt_based_cfg(cfg): +def verify_jwt_based_cfg(cfg: JwtConfig) -> None: if cfg.secret is None and cfg.keyfile is None: raise Exception("Missing secret or keyfile") if cfg.keyfile is not None and not os.path.exists(cfg.keyfile): @@ -224,54 +228,3 @@ def verify_jwt_based_cfg(cfg): "EdDSA", ]: raise Exception(f"Unknown algorithm: '{cfg.algorithm}'") - - -@dataclass -class NoAuth: - def header_map(self): - return {} - - -@dataclass -class BasicAuth: - username: str - password: str - - def header_map(self): - return basic_auth(self.username, self.password) - - -@dataclass -class BearerAuth: - token: str - - def header_map(self): - return bearer_auth(self.token) - - -HttpAuth: TypeAlias = Union[BasicAuth, BearerAuth, NoAuth] - - -def parse_auth(d: dict) -> HttpAuth: - if isinstance(d, dict): - type = d.pop("type") - if type is None: - return NoAuth() - elif type == "basic": - return BasicAuth(**d) - elif type == "bearer": - return BearerAuth(**d) - else: - raise Exception(f"Unknown HttpAuth type {type}") - elif isinstance(d, list): - type = d.pop(0) - if type is None: - return NoAuth() - elif type == "basic": - return BasicAuth(*d) - elif type == "bearer": - return BearerAuth(*d) - else: - raise Exception(f"Unknown HttpAuth type {type}") - else: - raise Exception("HttpAuth parsing failed, expected list or dict") diff --git a/synapse_token_authenticator/rest.py b/synapse_token_authenticator/rest.py new file mode 100644 index 0000000..3880f72 --- /dev/null +++ b/synapse_token_authenticator/rest.py @@ -0,0 +1,52 @@ +import json +from urllib.parse import urljoin + +from jwcrypto.jwk import JWKSet +from twisted.web import resource + +from synapse_token_authenticator.config import OIDCConfig + + +class LoginMetadataResource(resource.Resource): + def __init__(self, oidc_config: OIDCConfig): + super().__init__() + self.issuer = oidc_config.issuer + self.metadata_url = urljoin( + oidc_config.issuer, "/.well-known/openid-configuration" + ) + self.organization_id = oidc_config.organization_id + self.project_id = oidc_config.project_id + + def render_GET(self, request) -> bytes: + request.setHeader(b"content-type", b"application/json") + request.setHeader(b"access-control-allow-origin", b"*") + return json.dumps( + { + "issuer": self.issuer, + "issuer-metadata": self.metadata_url, + "organization-id": self.organization_id, + "project-id": self.project_id, + } + ).encode("utf-8") + + +class PublicKeysResource(resource.Resource): + def __init__(self, keys: JWKSet): + super().__init__() + self.keys = keys.export(private_keys=False).encode("utf-8") + + def render_GET(self, request) -> bytes: + request.setHeader(b"content-type", b"application/json") + request.setHeader(b"access-control-allow-origin", b"*") + return self.keys + + +class MetadataResource(resource.Resource): + def __init__(self, resource: object) -> None: + super().__init__() + self.resource = resource + + def render_GET(self, request) -> bytes: + request.setHeader(b"content-type", b"application/json") + request.setHeader(b"access-control-allow-origin", b"*") + return json.dumps(self.resource).encode("utf-8") diff --git a/synapse_token_authenticator/token_authenticator.py b/synapse_token_authenticator/token_authenticator.py index 69460ec..887b978 100644 --- a/synapse_token_authenticator/token_authenticator.py +++ b/synapse_token_authenticator/token_authenticator.py @@ -12,13 +12,13 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . +from __future__ import annotations + import base64 -import json import logging import re -from collections.abc import Awaitable -from typing import Callable, List, Optional, Tuple -from urllib.parse import urljoin +from collections.abc import Awaitable, Callable +from typing import TypeAlias import synapse from jwcrypto import jwk, jwt @@ -26,15 +26,18 @@ from jwcrypto.jwk import JWKSet from synapse.api.errors import HttpResponseException from synapse.module_api import ModuleApi -from synapse.types import UserID -from twisted.internet import defer -from twisted.web import resource +from synapse.rest.client.login import LoginResponse +from synapse.types import JsonDict, UserID +from synapse_token_authenticator.auth_headers import basic_auth from synapse_token_authenticator.config import TokenAuthenticatorConfig -from synapse_token_authenticator.utils import ( +from synapse_token_authenticator.rest import ( + LoginMetadataResource, MetadataResource, + PublicKeysResource, +) +from synapse_token_authenticator.utils import ( all_list_elems_are_equal_return_the_elem, - basic_auth, get_oidp_metadata, get_path_in_dict, if_not_none, @@ -44,109 +47,88 @@ logger = logging.getLogger(__name__) +TypeTokenAuthReturn: TypeAlias = tuple[ + str, + Callable[[LoginResponse], Awaitable[None]] | None, +] + + class TokenAuthenticator: __version__ = "0.13.1" - def __init__(self, config: dict, account_handler: ModuleApi): - self.api = account_handler + def __init__(self, config: TokenAuthenticatorConfig, module_api: ModuleApi) -> None: + self.api = module_api + self.config = config - auth_checkers = {} + auth_checkers: dict[ + tuple[str, tuple[str, ...]], + Callable[[str, str, JsonDict], Awaitable[TypeTokenAuthReturn | None]], + ] = {} - self.config: TokenAuthenticatorConfig = config - if (jwt := getattr(self.config, "jwt", None)) is not None: - if jwt.secret: + if self.config.jwt: + if self.config.jwt.secret: k = { - "k": base64.urlsafe_b64encode(jwt.secret.encode("utf-8")).decode( - "utf-8" - ), + "k": base64.urlsafe_b64encode( + self.config.jwt.secret.encode("utf-8") + ).decode("utf-8"), "kty": "oct", } self.key = jwk.JWK(**k) else: - with open(jwt.keyfile) as f: + with open(self.config.jwt.keyfile) as f: self.key = jwk.JWK.from_pem(f.read()) auth_checkers[("com.famedly.login.token", ("token",))] = self.check_jwt_auth - if (oidc := getattr(self.config, "oidc", None)) is not None: + if self.config.oidc: auth_checkers[("com.famedly.login.token.oidc", ("token",))] = ( self.check_oidc_auth ) self.api.register_web_resource( "/_famedly/login/com.famedly.login.token.oidc", - self.LoginMetadataResource(oidc), + LoginMetadataResource(self.config.oidc), ) - if (cfg := getattr(self.config, "oauth", None)) is not None: - if cfg.expose_metadata_resource: - resource_name = cfg.expose_metadata_resource["name"] + if self.config.oauth: + if self.config.oauth.expose_metadata_resource: + resource_name = self.config.oauth.expose_metadata_resource["name"] self.api.register_web_resource( f"/_famedly/login/{resource_name}", - MetadataResource(cfg.expose_metadata_resource), + MetadataResource(self.config.oauth.expose_metadata_resource), ) auth_checkers[("com.famedly.login.token.oauth", ("token",))] = ( self.check_oauth ) - if (cfg := getattr(self.config, "epa", None)) is not None: - if cfg.expose_metadata_resource: - resource_name = cfg.expose_metadata_resource["name"] + if self.config.epa: + if self.config.epa.expose_metadata_resource: + resource_name = self.config.epa.expose_metadata_resource["name"] self.api.register_web_resource( f"/_famedly/login/{resource_name}", - MetadataResource(cfg.expose_metadata_resource), + MetadataResource(self.config.epa.expose_metadata_resource), ) # Registers the encryption public keys keys = JWKSet() keys.add(self.config.epa.enc_jwk) self.api.register_web_resource( - self.config.epa.enc_jwks_endpoint, self.PublicKeysResource(keys) + self.config.epa.enc_jwks_endpoint, PublicKeysResource(keys) ) auth_checkers[("com.famedly.login.token.epa", ("token",))] = self.check_epa self.api.register_password_auth_provider_callbacks(auth_checkers=auth_checkers) - class LoginMetadataResource(resource.Resource): - def __init__(self, oidc_config: object): - self.issuer = oidc_config.issuer - self.metadata_url = urljoin( - oidc_config.issuer, "/.well-known/openid-configuration" - ) - self.organization_id = oidc_config.organization_id - self.project_id = oidc_config.project_id - - def render_GET(self, request): - request.setHeader(b"content-type", b"application/json") - request.setHeader(b"access-control-allow-origin", b"*") - return json.dumps( - { - "issuer": self.issuer, - "issuer-metadata": self.metadata_url, - "organization-id": self.organization_id, - "project-id": self.project_id, - } - ).encode("utf-8") - - class PublicKeysResource(resource.Resource): - def __init__(self, keys: JWKSet): - self.keys = keys.export(private_keys=False).encode("utf-8") - - def render_GET(self, request): - request.setHeader(b"content-type", b"application/json") - request.setHeader(b"access-control-allow-origin", b"*") - return self.keys - async def check_jwt_auth( - self, username: str, login_type: str, login_dict: "synapse.module_api.JsonDict" - ) -> Optional[ - tuple[ - str, - Optional[Callable[["synapse.module_api.LoginResponse"], Awaitable[None]]], - ] - ]: + self, username: str, login_type: str, login_dict: JsonDict + ) -> TypeTokenAuthReturn | None: logger.info("Receiving auth request") + + jwt_cfg = self.config.jwt + # Help mypy figure out that this is an actual JwtConfig + assert jwt_cfg is not None + if login_type != "com.famedly.login.token": logger.info("Wrong login type") return None @@ -156,7 +138,8 @@ async def check_jwt_auth( token = login_dict["token"] check_claims: dict = {} - if self.config.jwt.require_expiry: + + if jwt_cfg.require_expiry: check_claims["exp"] = None try: # OK, let's verify the token @@ -164,7 +147,7 @@ async def check_jwt_auth( jwt=token, key=self.key, check_claims=check_claims, - algs=[self.config.jwt.algorithm], + algs=[jwt_cfg.algorithm], ) except ValueError as e: logger.info("Unrecognized token %s", e) @@ -206,7 +189,7 @@ async def check_jwt_auth( return None user_exists = await self.api.check_user_exists(user_id_str) - if not user_exists and not self.config.jwt.allow_registration: + if not user_exists and not jwt_cfg.allow_registration: logger.info("User doesn't exist and registration is disabled") return None @@ -231,14 +214,14 @@ async def check_jwt_auth( return (user_id_str, None) async def check_oidc_auth( - self, username: str, login_type: str, login_dict: "synapse.module_api.JsonDict" - ) -> Optional[ - tuple[ - str, - Optional[Callable[["synapse.module_api.LoginResponse"], Awaitable[None]]], - ] - ]: + self, username: str, login_type: str, login_dict: JsonDict + ) -> TypeTokenAuthReturn | None: logger.info("Receiving auth request") + + oidc_cfg = self.config.oidc + # Help mypy figure out that this is an actual OIDCConfig + assert oidc_cfg is not None + if login_type != "com.famedly.login.token.oidc": logger.info("Wrong login type") return None @@ -248,8 +231,8 @@ async def check_oidc_auth( token = login_dict["token"] client = self.api._hs.get_proxied_http_client() - oidc = self.config.oidc - oidc_metadata = await get_oidp_metadata(oidc.issuer, client) + + oidc_metadata = await get_oidp_metadata(oidc_cfg.issuer, client) # Further validation using token introspection data = {"token": token, "token_type_hint": "access_token", "scope": "openid"} @@ -258,7 +241,7 @@ async def check_oidc_auth( introspection_resp = await client.post_urlencoded_get_json( oidc_metadata.introspection_endpoint, data, - headers=basic_auth(oidc.client_id, oidc.client_secret), + headers=basic_auth(oidc_cfg.client_id, oidc_cfg.client_secret), ) except HttpResponseException as e: if e.code == 401: @@ -277,7 +260,7 @@ async def check_oidc_auth( [ role in allowed_roles for role in introspection_resp[ - f"urn:zitadel:iam:org:project:{oidc.project_id}:roles" + f"urn:zitadel:iam:org:project:{oidc_cfg.project_id}:roles" ] ] ): @@ -289,8 +272,8 @@ async def check_oidc_auth( return None if ( - oidc.allowed_client_ids is not None - and introspection_resp["client_id"] not in oidc.allowed_client_ids + oidc_cfg.allowed_client_ids is not None + and introspection_resp["client_id"] not in oidc_cfg.allowed_client_ids ): logger.info( f"Client {introspection_resp['client_id']} is not in the list of allowed clients" @@ -306,7 +289,7 @@ async def check_oidc_auth( return None user_exists = await self.api.check_user_exists(user_id_str) - if not user_exists and not self.config.oidc.allow_registration: + if not user_exists and not oidc_cfg.allow_registration: logger.info("User doesn't exist and registration is disabled") return None @@ -320,15 +303,14 @@ async def check_oidc_auth( return (user_id_str, None) async def check_oauth( - self, username: str, login_type: str, login_dict: "synapse.module_api.JsonDict" - ) -> Optional[ - tuple[ - str, - Optional[Callable[["synapse.module_api.LoginResponse"], Awaitable[None]]], - ] - ]: - config = self.config.oauth + self, username: str, login_type: str, login_dict: JsonDict + ) -> TypeTokenAuthReturn | None: logger.info("Receiving auth request") + + oauth_config = self.config.oauth + # Help mypy figure out that this is an actual OAuthConfig + assert oauth_config is not None + if login_type != "com.famedly.login.token.oauth": logger.info("Wrong login type") return None @@ -341,19 +323,19 @@ async def check_oauth( jwt_claims = {} - if config.jwt_validation is not None: + if oauth_config.jwt_validation is not None: check_claims: dict = {} - if config.jwt_validation.require_expiry: + if oauth_config.jwt_validation.require_expiry: check_claims["exp"] = None - if config.jwt_validation.jwks_endpoint: + if oauth_config.jwt_validation.jwks_endpoint: jwks_json = await client.get_raw( - config.jwt_validation.jwks_endpoint, + oauth_config.jwt_validation.jwks_endpoint, ) - config.jwt_validation.jwk_set = JWKSet.from_json(jwks_json) + oauth_config.jwt_validation.jwk_set = JWKSet.from_json(jwks_json) try: token = jwt.JWT( jwt=token, - key=config.jwt_validation.jwk_set, + key=oauth_config.jwt_validation.jwk_set, check_claims=check_claims, ) except ValueError as e: @@ -365,30 +347,30 @@ async def check_oauth( jwt_claims = json_decode(token.claims) - if config.jwt_validation.required_scopes: + if oauth_config.jwt_validation.required_scopes: provided_scope = jwt_claims.get("scope") if not isinstance(provided_scope, str): logger.info("Token missing scope claim") return None if not validate_scopes( - config.jwt_validation.required_scopes, provided_scope + oauth_config.jwt_validation.required_scopes, provided_scope ): logger.info("Token scope validation failed") return None - if not config.jwt_validation.validator.validate(jwt_claims): + if not oauth_config.jwt_validation.validator.validate(jwt_claims): logger.info("Token claims validation failed") return None introspection_claims = {} - if config.introspection_validation is not None: + if oauth_config.introspection_validation is not None: try: introspection_claims = await client.post_urlencoded_get_json( - config.introspection_validation.endpoint, + oauth_config.introspection_validation.endpoint, {"token": token}, - headers=config.introspection_validation.auth.header_map(), + headers=oauth_config.introspection_validation.auth.header_map(), ) except HttpResponseException as e: if e.code == 401: @@ -397,19 +379,20 @@ async def check_oauth( else: raise e - if config.introspection_validation.required_scopes: + if oauth_config.introspection_validation.required_scopes: provided_scope = introspection_claims.get("scope") if not isinstance(provided_scope, str): logger.info("Token missing scope claim") return None if not validate_scopes( - config.introspection_validation.required_scopes, provided_scope + oauth_config.introspection_validation.required_scopes, + provided_scope, ): logger.info("Token scope validation failed") return None - if not config.introspection_validation.validator.validate( + if not oauth_config.introspection_validation.validator.validate( introspection_claims ): logger.info("Introspection response validation failed for a token") @@ -420,16 +403,18 @@ async def check_oauth( def get_from_set(set_): return if_not_none(lambda path: get_path_in_dict(path, set_)) - username_type = config.username_type + username_type = oauth_config.username_type try: get_localpart_mb = if_not_none(lambda x: x.localpart_path) localpart = all_list_elems_are_equal_return_the_elem( [ - get_from_set(jwt_claims)(get_localpart_mb(config.jwt_validation)), + get_from_set(jwt_claims)( + get_localpart_mb(oauth_config.jwt_validation) + ), get_from_set(introspection_claims)( - get_localpart_mb(config.introspection_validation) + get_localpart_mb(oauth_config.introspection_validation) ), username if username_type == "localpart" else None, ( @@ -451,9 +436,11 @@ def get_from_set(set_): fully_qualified_uid = all_list_elems_are_equal_return_the_elem( [ - get_from_set(jwt_claims)(get_fq_uid_mb(config.jwt_validation)), + get_from_set(jwt_claims)( + get_fq_uid_mb(oauth_config.jwt_validation) + ), get_from_set(introspection_claims)( - get_fq_uid_mb(config.introspection_validation) + get_fq_uid_mb(oauth_config.introspection_validation) ), username if username_type == "fq_uid" else None, ( @@ -486,9 +473,11 @@ def get_from_set(set_): get_displayname_mb = if_not_none(lambda x: x.displayname_path) displayname = all_list_elems_are_equal_return_the_elem( [ - get_from_set(jwt_claims)(get_displayname_mb(config.jwt_validation)), + get_from_set(jwt_claims)( + get_displayname_mb(oauth_config.jwt_validation) + ), get_from_set(introspection_claims)( - get_displayname_mb(config.introspection_validation) + get_displayname_mb(oauth_config.introspection_validation) ), ] ) @@ -500,9 +489,9 @@ def get_from_set(set_): get_admin_mb = if_not_none(lambda x: x.admin_path) admin = all_list_elems_are_equal_return_the_elem( [ - get_from_set(jwt_claims)(get_admin_mb(config.jwt_validation)), + get_from_set(jwt_claims)(get_admin_mb(oauth_config.jwt_validation)), get_from_set(introspection_claims)( - get_admin_mb(config.introspection_validation) + get_admin_mb(oauth_config.introspection_validation) ), ] ) @@ -514,9 +503,9 @@ def get_from_set(set_): get_email_mb = if_not_none(lambda x: x.email_path) email = all_list_elems_are_equal_return_the_elem( [ - get_from_set(jwt_claims)(get_email_mb(config.jwt_validation)), + get_from_set(jwt_claims)(get_email_mb(oauth_config.jwt_validation)), get_from_set(introspection_claims)( - get_email_mb(config.introspection_validation) + get_email_mb(oauth_config.introspection_validation) ), ] ) @@ -554,28 +543,28 @@ def get_from_set(set_): user_exists = await self.api.check_user_exists(fully_qualified_uid) - if not user_exists and not config.registration_enabled: + if not user_exists and not oauth_config.registration_enabled: logger.info("User doesn't exist and registration is disabled") return None if not user_exists: logger.info("User doesn't exist, registering them...") - if config.notify_on_registration: + if oauth_config.notify_on_registration: try: await client.post_json_get_json( - config.notify_on_registration.url, + oauth_config.notify_on_registration.url, { "localpart": localpart, "fully_qualified_uid": fully_qualified_uid, "displayname": displayname, }, - headers=config.notify_on_registration.auth.header_map(), + headers=oauth_config.notify_on_registration.auth.header_map(), ) except ValueError: pass except HttpResponseException as e: logger.info(e) - if config.notify_on_registration.interrupt_on_error: + if oauth_config.notify_on_registration.interrupt_on_error: return None user_id = await self.api.register_user(localpart, admin=bool(admin)) @@ -595,7 +584,7 @@ def get_from_set(set_): logger.info("Registered user %s (%s)", localpart, displayname) - if config.check_external_id and user_exists: + if oauth_config.check_external_id and user_exists: external_ids = await self._get_external_id(fully_qualified_uid) if ( len(external_ids) > 0 @@ -620,15 +609,14 @@ def get_from_set(set_): return (fully_qualified_uid, None) async def check_epa( - self, _username: str, login_type: str, login_dict: "synapse.module_api.JsonDict" - ) -> Optional[ - tuple[ - str, - Optional[Callable[["synapse.module_api.LoginResponse"], Awaitable[None]]], - ] - ]: - config = self.config.epa + self, _username: str, login_type: str, login_dict: JsonDict + ) -> TypeTokenAuthReturn | None: logger.info("Receiving auth request") + + epa_config = self.config.epa + # Makes Mypy realize that this is an actual EPaConfig + assert epa_config is not None + if login_type != "com.famedly.login.token.epa": logger.info("Wrong login type") return None @@ -637,22 +625,22 @@ async def check_epa( return None token = login_dict["token"] - if config.jwks_endpoint: + if epa_config.jwks_endpoint: client = self.api._hs.get_proxied_http_client() jwks_json = await client.get_raw( - config.jwks_endpoint, + epa_config.jwks_endpoint, ) - config.jwk_set = JWKSet.from_json(jwks_json) + epa_config.jwk_set = JWKSet.from_json(jwks_json) check_claims: dict = { - "iss": config.iss, + "iss": epa_config.iss, "exp": None, } try: - enc_token = jwt.JWT(key=config.enc_jwk, jwt=token, expected_type="JWE") + enc_token = jwt.JWT(key=epa_config.enc_jwk, jwt=token, expected_type="JWE") token = jwt.JWT( jwt=enc_token.claims, - key=config.jwk_set, + key=epa_config.jwk_set, check_claims=check_claims, ) except ValueError as e: @@ -685,23 +673,23 @@ async def check_epa( if "aud" not in jwt_claims: logger.info("Token missing 'aud' claim") return None - if config.resource_id != jwt_claims["aud"]: + if epa_config.resource_id != jwt_claims["aud"]: logger.info( - f"Token has the wrong 'aud'. The expected value is '{config.resource_id}'" + f"Token has the wrong 'aud'. The expected value is '{epa_config.resource_id}'" ) return None - localpart = get_path_in_dict(config.localpart_path, jwt_claims) - displayname = get_path_in_dict(config.displayname_path, jwt_claims) + localpart = get_path_in_dict(epa_config.localpart_path, jwt_claims) + displayname = get_path_in_dict(epa_config.displayname_path, jwt_claims) if not localpart: logger.info("Missing localpart") return None - if config.lowercase_localpart: + if epa_config.lowercase_localpart: localpart = localpart.lower() - if not config.validator.validate(jwt_claims): + if not epa_config.validator.validate(jwt_claims): logger.info("Token claims validation failed") return None @@ -709,7 +697,7 @@ async def check_epa( user_exists = await self.api.check_user_exists(fully_qualified_uid) - if not user_exists and not config.registration_enabled: + if not user_exists and not epa_config.registration_enabled: logger.info("User doesn't exist and registration is disabled") return None @@ -731,19 +719,13 @@ async def check_epa( return (fully_qualified_uid, None) @staticmethod - def parse_config(config: dict): + def parse_config(config: dict) -> TokenAuthenticatorConfig: return TokenAuthenticatorConfig(config) - def _add_user_email(self, user_id, email) -> defer.Deferred: - return defer.ensureDeferred( - self.api._auth_handler.add_threepid( - user_id, "email", email, self.api._hs.get_clock().time_msec() - ) + async def _add_user_email(self, user_id: str, email: str) -> None: + return await self.api._auth_handler.add_threepid( + user_id, "email", email, self.api._hs.get_clock().time_msec() ) - def _get_external_id( - self, fully_qualified_uid: str - ) -> "defer.Deferred[List[Tuple[str, str]]]": - return defer.ensureDeferred( - self.api._store.get_external_ids_by_user(fully_qualified_uid) - ) + async def _get_external_id(self, fully_qualified_uid: str) -> list[tuple[str, str]]: + return await self.api._store.get_external_ids_by_user(fully_qualified_uid) diff --git a/synapse_token_authenticator/utils.py b/synapse_token_authenticator/utils.py index 322caf5..47e7b3d 100644 --- a/synapse_token_authenticator/utils.py +++ b/synapse_token_authenticator/utils.py @@ -1,17 +1,13 @@ -import json -from base64 import b64encode from typing import Any, List, Optional from urllib.parse import urljoin -from twisted.web import resource - class OpenIDProviderMetadata: """ Wrapper around OpenID Provider Metadata values """ - def __init__(self, issuer: str, configuration: dict): + def __init__(self, issuer: str, configuration: dict) -> None: self.issuer = issuer self.introspection_endpoint: str = configuration["introspection_endpoint"] self.jwks_uri: str = configuration["jwks_uri"] @@ -27,17 +23,6 @@ async def get_oidp_metadata(issuer, client) -> OpenIDProviderMetadata: return OpenIDProviderMetadata(issuer, config) -def basic_auth(username: str, password: str) -> dict[bytes, list[bytes]]: - authorization = b64encode( - b":".join((username.encode("utf8"), password.encode("utf8"))) - ) - return {b"Authorization": [b"Basic " + authorization]} - - -def bearer_auth(token: str) -> dict[bytes, list[bytes]]: - return {b"Authorization": [b"Bearer " + token.encode("utf8")]} - - def if_not_none(f): return lambda x: (f(x) if x is not None else None) @@ -74,13 +59,3 @@ def validate_scopes(required_scopes: str | List[str], provided_scopes: str) -> b required_scopes = required_scopes.split() provided_scopes_list = provided_scopes.split() return all(scope in provided_scopes_list for scope in required_scopes) - - -class MetadataResource(resource.Resource): - def __init__(self, resource: object): - self.resource = resource - - def render_GET(self, request): - request.setHeader(b"content-type", b"application/json") - request.setHeader(b"access-control-allow-origin", b"*") - return json.dumps(self.resource).encode("utf-8") diff --git a/tests/__init__.py b/tests/__init__.py index a6f69a6..299b0a5 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -30,7 +30,7 @@ import tests.unittest as synapsetest from tests.test_utils import FakeResponse as Response -admins = {} +admins: dict[str, bool] = {} logger = logging.getLogger(__name__) ENC_JWK = jwk.JWK.generate(kty="RSA", size=2048) # secrets for token generation need to be 64 chars long, as it needs to have 512 bits @@ -40,17 +40,17 @@ class ModuleApiTestCase(synapsetest.HomeserverTestCase): @classmethod - def setUpClass(cls): - async def set_user_admin(user_id: str, admin: bool): + def setUpClass(cls) -> None: + async def set_user_admin(user_id: str, admin: bool) -> None: return admins.update({user_id: admin}) - async def is_user_admin(user_id: str): + async def is_user_admin(user_id: str) -> bool: return admins.get(user_id, False) async def register_user( localpart: str, admin: bool = False, - ): + ) -> str: return "@alice:example.test" cls.patchers = [ @@ -89,11 +89,13 @@ def tearDownClass(cls): def prepare( self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer ) -> None: + super().prepare(reactor, clock, homeserver) self.store = homeserver.get_datastores().main self.module_api = homeserver.get_module_api() self.event_creation_handler = homeserver.get_event_creation_handler() self.sync_handler = homeserver.get_sync_handler() self.auth_handler = homeserver.get_auth_handler() + self.token_authenticator = homeserver.mockmod # type: ignore[attr-defined] @override def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: @@ -212,7 +214,12 @@ def get_jwe_token( "typ": "JWE", "kid": enc_key.key_id, } - jwetoken = jwe.JWE(token, recipient=enc_key.public(), protected=protected_header) + # The recipient kwarg is mistyped in type-sched. It should be `JWK | None` and is + # instead labeled as a `str | None`. The `public()` function is correct so this will + # be ignored. + jwetoken = jwe.JWE( + token, recipient=enc_key.public(), protected=json.dumps(protected_header) # type: ignore[arg-type] + ) return jwetoken.serialize(True) diff --git a/tests/test_epa.py b/tests/test_epa.py index d607f86..6603429 100644 --- a/tests/test_epa.py +++ b/tests/test_epa.py @@ -17,13 +17,14 @@ from unittest import mock from jwcrypto import jwk +from synapse.types import JsonDict import tests.unittest as synapsetest from . import ModuleApiTestCase, get_enc_jwk, get_jwe_token, get_jwk, get_jwt_token -def get_default_claims() -> dict: +def get_default_claims() -> JsonDict: return { "aud": "https://famedly.de", "jti": "666f3725783e5356544fce5d869", @@ -32,115 +33,115 @@ def get_default_claims() -> dict: class CustomFlowTests(ModuleApiTestCase): - async def test_wrong_login_type(self): + async def test_wrong_login_type(self) -> None: token = get_jwe_token("alice", claims=get_default_claims()) - result = await self.hs.mockmod.check_epa( + result = await self.token_authenticator.check_epa( "alice", "com.famedly.login.token", {"token": token} ) self.assertEqual(result, None) - async def test_missing_token(self): - result = await self.hs.mockmod.check_epa( + async def test_missing_token(self) -> None: + result = await self.token_authenticator.check_epa( "alice", "com.famedly.login.token.epa", {} ) self.assertEqual(result, None) - async def test_invalid_token(self): - result = await self.hs.mockmod.check_epa( + async def test_invalid_token(self) -> None: + result = await self.token_authenticator.check_epa( "alice", "com.famedly.login.token.epa", {"token": "invalid"} ) self.assertEqual(result, None) - async def test_token_wrong_secret(self): + async def test_token_wrong_secret(self) -> None: # The secret needs to be 64 bytes, so pad it and bulk copy it. 16 * 4 = 64 secret = "wrong secret1234" * 4 token = get_jwe_token("alice", secret=secret, claims=get_default_claims()) - result = await self.hs.mockmod.check_epa( + result = await self.token_authenticator.check_epa( "alice", "com.famedly.login.token.epa", {"token": token} ) self.assertEqual(result, None) - async def test_token_expired(self): + async def test_token_expired(self) -> None: token = get_jwe_token("alice", exp_in=-60, claims=get_default_claims()) - result = await self.hs.mockmod.check_epa( + result = await self.token_authenticator.check_epa( "alice", "com.famedly.login.token.epa", {"token": token} ) self.assertEqual(result, None) - async def test_token_no_expiry(self): + async def test_token_no_expiry(self) -> None: token = get_jwe_token("alice", exp_in=-1, claims=get_default_claims()) - result = await self.hs.mockmod.check_epa( + result = await self.token_authenticator.check_epa( "alice", "com.famedly.login.token.epa", {"token": token} ) self.assertEqual(result, None) - async def test_username_ignored(self): + async def test_username_ignored(self) -> None: token = get_jwe_token("alice", claims=get_default_claims()) - result = await self.hs.mockmod.check_epa( + result = await self.token_authenticator.check_epa( "dont_match", "com.famedly.login.token.epa", {"token": token} ) self.assertEqual(result[0], "@alice:example.test") - async def test_token_missing_typ(self): + async def test_token_missing_typ(self) -> None: token = get_jwe_token("alice", claims=get_default_claims(), extra_headers={}) - result = await self.hs.mockmod.check_epa( + result = await self.token_authenticator.check_epa( "alice", "com.famedly.login.token.epa", {"token": token} ) self.assertEqual(result, None) - async def test_token_wrong_typ(self): + async def test_token_wrong_typ(self) -> None: token = get_jwe_token( "alice", claims=get_default_claims(), extra_headers={"typ": "wrong"} ) - result = await self.hs.mockmod.check_epa( + result = await self.token_authenticator.check_epa( "alice", "com.famedly.login.token.epa", {"token": token} ) self.assertEqual(result, None) - async def test_token_missing_aud(self): + async def test_token_missing_aud(self) -> None: claims = get_default_claims() claims.pop("aud") token = get_jwe_token("alice", claims=claims) - result = await self.hs.mockmod.check_epa( + result = await self.token_authenticator.check_epa( "alice", "com.famedly.login.token.epa", {"token": token} ) self.assertEqual(result, None) - async def test_login(self): + async def test_login(self) -> None: token = get_jwe_token("alice", claims=get_default_claims()) - result = await self.hs.mockmod.check_epa( + result = await self.token_authenticator.check_epa( "alice", "com.famedly.login.token.epa", {"token": token} ) self.assertEqual(result[0], "@alice:example.test") - async def test_login_alternative_typ(self): + async def test_login_alternative_typ(self) -> None: token = get_jwe_token( "alice", claims=get_default_claims(), extra_headers={"typ": "application/at+jwt"}, ) - result = await self.hs.mockmod.check_epa( + result = await self.token_authenticator.check_epa( "alice", "com.famedly.login.token.epa", {"token": token} ) self.assertEqual(result[0], "@alice:example.test") - async def test_token_missing_jti(self): + async def test_token_missing_jti(self) -> None: claims = get_default_claims() claims.pop("jti") token = get_jwe_token("alice", claims=claims) - result = await self.hs.mockmod.check_epa( + result = await self.token_authenticator.check_epa( "alice", "com.famedly.login.token.epa", {"token": token} ) self.assertEqual(result, None) - async def test_token_token_not_enc(self): + async def test_token_token_not_enc(self) -> None: token = get_jwt_token("alice", claims=get_default_claims()) - result = await self.hs.mockmod.check_epa( + result = await self.token_authenticator.check_epa( "alice", "com.famedly.login.token.epa", {"token": token} ) self.assertEqual(result, None) - config_for_epa = { + config_for_epa: JsonDict = { "modules": [ { "module": "synapse_token_authenticator.TokenAuthenticator", @@ -163,9 +164,9 @@ async def test_token_token_not_enc(self): config_for_epa_wrong_iss["modules"][0]["config"]["epa"]["iss"] = "wrong_iss" @synapsetest.override_config(config_for_epa_wrong_iss) - async def test_token_wrong_iss(self): + async def test_token_wrong_iss(self) -> None: token = get_jwe_token("alice", claims=get_default_claims()) - result = await self.hs.mockmod.check_epa( + result = await self.token_authenticator.check_epa( "alice", "com.famedly.login.token.epa", {"token": token} ) self.assertEqual(result, None) @@ -174,17 +175,17 @@ async def test_token_wrong_iss(self): config_for_epa_wrong_aud["modules"][0]["config"]["epa"]["resource_id"] = "wrong_aud" @synapsetest.override_config(config_for_epa_wrong_aud) - async def test_token_wrong_aud(self): + async def test_token_wrong_aud(self) -> None: token = get_jwe_token("alice", claims=get_default_claims()) - result = await self.hs.mockmod.check_epa( + result = await self.token_authenticator.check_epa( "alice", "com.famedly.login.token.epa", {"token": token} ) self.assertEqual(result, None) @mock.patch("synapse.module_api.ModuleApi.check_user_exists", return_value=False) - async def test_valid_login_register(self, *args): + async def test_valid_login_register(self, *args) -> None: token = get_jwe_token("alice", claims=get_default_claims()) - result = await self.hs.mockmod.check_epa( + result = await self.token_authenticator.check_epa( "alice", "com.famedly.login.token.epa", {"token": token} ) self.assertEqual(result[0], "@alice:example.test") @@ -201,9 +202,9 @@ async def test_valid_login_register(self, *args): @mock.patch( "synapse.http.client.SimpleHttpClient.get_raw", return_value=jwks.export() ) - async def test_fetch_jwks(self, *args): + async def test_fetch_jwks(self, *args) -> None: token = get_jwe_token("alice", claims=get_default_claims()) - result = await self.hs.mockmod.check_epa( + result = await self.token_authenticator.check_epa( "alice", "com.famedly.login.token.epa", {"token": token} ) self.assertEqual(result[0], "@alice:example.test") @@ -215,9 +216,9 @@ async def test_fetch_jwks(self, *args): @synapsetest.override_config(config_for_epa_reg_disabled) @mock.patch("synapse.module_api.ModuleApi.check_user_exists", return_value=False) - async def test_valid_login_registration_disabled(self, *args): + async def test_valid_login_registration_disabled(self, *args) -> None: token = get_jwe_token("alice", claims=get_default_claims()) - result = await self.hs.mockmod.check_epa( + result = await self.token_authenticator.check_epa( "alice", "com.famedly.login.token.epa", {"token": token} ) self.assertEqual(result, None) @@ -228,16 +229,16 @@ async def test_valid_login_registration_disabled(self, *args): ] = True @synapsetest.override_config(config_for_epa_lowercase) - async def test_localpart_lowercase(self): + async def test_localpart_lowercase(self) -> None: token = get_jwe_token("AlIcE", claims=get_default_claims()) - result = await self.hs.mockmod.check_epa( + result = await self.token_authenticator.check_epa( "alice", "com.famedly.login.token.epa", {"token": token} ) self.assertEqual(result[0], "@alice:example.test") - async def test_localpart_not_lowercase(self): + async def test_localpart_not_lowercase(self) -> None: token = get_jwe_token("AlIcE", claims=get_default_claims()) - result = await self.hs.mockmod.check_epa( + result = await self.token_authenticator.check_epa( "alice", "com.famedly.login.token.epa", {"token": token} ) self.assertEqual(result[0], "@AlIcE:example.test") diff --git a/tests/test_jwt.py b/tests/test_jwt.py index 0cab91f..031cf37 100644 --- a/tests/test_jwt.py +++ b/tests/test_jwt.py @@ -21,51 +21,51 @@ class JWTTests(ModuleApiTestCase): - async def test_wrong_login_type(self): + async def test_wrong_login_type(self) -> None: token = get_jwt_token("alice") - result = await self.hs.mockmod.check_jwt_auth( + result = await self.token_authenticator.check_jwt_auth( "alice", "m.password", {"token": token} ) self.assertEqual(result, None) - async def test_missing_token(self): - result = await self.hs.mockmod.check_jwt_auth( + async def test_missing_token(self) -> None: + result = await self.token_authenticator.check_jwt_auth( "alice", "com.famedly.login.token", {} ) self.assertEqual(result, None) - async def test_invalid_token(self): - result = await self.hs.mockmod.check_jwt_auth( + async def test_invalid_token(self) -> None: + result = await self.token_authenticator.check_jwt_auth( "alice", "com.famedly.login.token", {"token": "invalid"} ) self.assertEqual(result, None) - async def test_token_wrong_secret(self): + async def test_token_wrong_secret(self) -> None: # The secret needs to be 64 bytes, so pad it and bulk copy it. 16 * 4 = 64 secret = "wrong secret1234" * 4 token = get_jwt_token("alice", secret=secret) - result = await self.hs.mockmod.check_jwt_auth( + result = await self.token_authenticator.check_jwt_auth( "alice", "com.famedly.login.token", {"token": token} ) self.assertEqual(result, None) - async def test_token_wrong_alg(self): + async def test_token_wrong_alg(self) -> None: token = get_jwt_token("alice", algorithm="HS256") - result = await self.hs.mockmod.check_jwt_auth( + result = await self.token_authenticator.check_jwt_auth( "alice", "com.famedly.login.token", {"token": token} ) self.assertEqual(result, None) - async def test_token_expired(self): + async def test_token_expired(self) -> None: token = get_jwt_token("alice", exp_in=-60) - result = await self.hs.mockmod.check_jwt_auth( + result = await self.token_authenticator.check_jwt_auth( "alice", "com.famedly.login.token", {"token": token} ) self.assertEqual(result, None) - async def test_token_no_expiry(self): + async def test_token_no_expiry(self) -> None: token = get_jwt_token("alice", exp_in=-1) - result = await self.hs.mockmod.check_jwt_auth( + result = await self.token_authenticator.check_jwt_auth( "alice", "com.famedly.login.token", {"token": token} ) self.assertEqual(result, None) @@ -85,33 +85,33 @@ async def test_token_no_expiry(self): ] } ) - async def test_token_no_expiry_with_config(self, *args): + async def test_token_no_expiry_with_config(self, *args) -> None: token = get_jwt_token("alice", exp_in=-1) - result = await self.hs.mockmod.check_jwt_auth( + result = await self.token_authenticator.check_jwt_auth( "alice", "com.famedly.login.token", {"token": token} ) self.assertEqual(result[0], "@alice:example.test") - async def test_valid_login(self): + async def test_valid_login(self) -> None: token = get_jwt_token("alice") - result = await self.hs.mockmod.check_jwt_auth( + result = await self.token_authenticator.check_jwt_auth( "alice", "com.famedly.login.token", {"token": token} ) self.assertEqual(result[0], "@alice:example.test") @mock.patch("synapse.module_api.ModuleApi.check_user_exists", return_value=False) - async def test_valid_login_no_register(self, *args): + async def test_valid_login_no_register(self, *args) -> None: token = get_jwt_token("alice") - result = await self.hs.mockmod.check_jwt_auth( + result = await self.token_authenticator.check_jwt_auth( "alice", "com.famedly.login.token", {"token": token} ) self.assertEqual(result, None) - async def test_chatbox_login(self): + async def test_chatbox_login(self) -> None: token = get_jwt_token( "alice_5833eb34-7dbf-44a7-90cf-868c50922c06", claims={"type": "chatbox"} ) - result = await self.hs.mockmod.check_jwt_auth( + result = await self.token_authenticator.check_jwt_auth( "alice_5833eb34-7dbf-44a7-90cf-868c50922c06", "com.famedly.login.token", {"token": token}, @@ -121,9 +121,9 @@ async def test_chatbox_login(self): ) @mock.patch("synapse.module_api.ModuleApi.check_user_exists", return_value=False) - async def test_chatbox_login_invalid_format(self, *args): + async def test_chatbox_login_invalid_format(self, *args) -> None: token = get_jwt_token("alice", claims={"type": "chatbox"}) - result = await self.hs.mockmod.check_jwt_auth( + result = await self.token_authenticator.check_jwt_auth( "alice", "com.famedly.login.token", {"token": token} ) self.assertEqual(result, None) @@ -144,16 +144,16 @@ async def test_chatbox_login_invalid_format(self, *args): ] } ) - async def test_valid_login_with_register(self, *args): + async def test_valid_login_with_register(self, *args) -> None: token = get_jwt_token("alice") - result = await self.hs.mockmod.check_jwt_auth( + result = await self.token_authenticator.check_jwt_auth( "alice", "com.famedly.login.token", {"token": token} ) self.assertEqual(result[0], "@alice:example.test") - async def test_valid_login_with_admin(self): + async def test_valid_login_with_admin(self) -> None: token = get_jwt_token("alice", admin=True) - result = await self.hs.mockmod.check_jwt_auth( + result = await self.token_authenticator.check_jwt_auth( "alice", "com.famedly.login.token", {"token": token} ) self.assertEqual(result[0], "@alice:example.test") diff --git a/tests/test_oauth.py b/tests/test_oauth.py index da239cc..8b44120 100644 --- a/tests/test_oauth.py +++ b/tests/test_oauth.py @@ -17,6 +17,7 @@ from unittest import mock from jwcrypto.jwk import JWKSet +from synapse.types import JsonDict import tests.unittest as synapsetest @@ -37,74 +38,74 @@ class CustomFlowTests(ModuleApiTestCase): - async def test_wrong_login_type(self): + async def test_wrong_login_type(self) -> None: token = get_jwt_token("aliceid", claims=default_claims) - result = await self.hs.mockmod.check_oauth( + result = await self.token_authenticator.check_oauth( "alice", "com.famedly.login.token", {"token": token} ) self.assertEqual(result, None) - async def test_missing_token(self): - result = await self.hs.mockmod.check_oauth( + async def test_missing_token(self) -> None: + result = await self.token_authenticator.check_oauth( "alice", "com.famedly.login.token.oauth", {} ) self.assertEqual(result, None) - async def test_invalid_token(self): - result = await self.hs.mockmod.check_oauth( + async def test_invalid_token(self) -> None: + result = await self.token_authenticator.check_oauth( "alice", "com.famedly.login.token.oauth", {"token": "invalid"} ) self.assertEqual(result, None) - async def test_token_wrong_secret(self): + async def test_token_wrong_secret(self) -> None: # The secret needs to be 64 bytes, so pad it and bulk copy it. 16 * 4 = 64 secret = "wrong secret1234" * 4 token = get_jwt_token("aliceid", secret=secret, claims=default_claims) - result = await self.hs.mockmod.check_oauth( + result = await self.token_authenticator.check_oauth( "alice", "com.famedly.login.token.oauth", {"token": token} ) self.assertEqual(result, None) - async def test_token_expired(self): + async def test_token_expired(self) -> None: token = get_jwt_token("aliceid", exp_in=-60, claims=default_claims) - result = await self.hs.mockmod.check_oauth( + result = await self.token_authenticator.check_oauth( "alice", "com.famedly.login.token.oauth", {"token": token} ) self.assertEqual(result, None) - async def test_token_no_expiry(self): + async def test_token_no_expiry(self) -> None: token = get_jwt_token("aliceid", exp_in=-1, claims=default_claims) - result = await self.hs.mockmod.check_oauth( + result = await self.token_authenticator.check_oauth( "alice", "com.famedly.login.token.oauth", {"token": token} ) self.assertEqual(result, None) - async def test_token_bad_localpart(self): + async def test_token_bad_localpart(self) -> None: claims = default_claims.copy() claims["urn:messaging:matrix:localpart"] = "bobby" token = get_jwt_token("aliceid", claims=claims) - result = await self.hs.mockmod.check_oauth( + result = await self.token_authenticator.check_oauth( "alice", "com.famedly.login.token.oauth", {"token": token} ) self.assertEqual(result, None) - async def test_token_bad_mxid(self): + async def test_token_bad_mxid(self) -> None: claims = default_claims.copy() claims["urn:messaging:matrix:mxid"] = "@bobby:example.test" token = get_jwt_token("aliceid", claims=claims) - result = await self.hs.mockmod.check_oauth( + result = await self.token_authenticator.check_oauth( "alice", "com.famedly.login.token.oauth", {"token": token} ) self.assertEqual(result, None) - async def test_token_claims_username_mismatch(self): + async def test_token_claims_username_mismatch(self) -> None: token = get_jwt_token("aliceid", claims=default_claims) - result = await self.hs.mockmod.check_oauth( + result = await self.token_authenticator.check_oauth( "bobby", "com.famedly.login.token.oauth", {"token": token} ) self.assertEqual(result, None) - config_for_jwt = { + config_for_jwt: JsonDict = { "modules": [ { "module": "synapse_token_authenticator.TokenAuthenticator", @@ -129,9 +130,9 @@ async def test_token_claims_username_mismatch(self): @synapsetest.override_config(config_for_jwt_reg_disabled) @mock.patch("synapse.module_api.ModuleApi.check_user_exists", return_value=False) - async def test_valid_login_registration_disabled(self, *args): + async def test_valid_login_registration_disabled(self, *args) -> None: token = get_jwt_token("alice", claims=default_claims) - result = await self.hs.mockmod.check_oauth( + result = await self.token_authenticator.check_oauth( "alice", "com.famedly.login.token.epa", {"token": token} ) self.assertEqual(result, None) @@ -142,9 +143,9 @@ async def test_valid_login_registration_disabled(self, *args): new_callable=mock.AsyncMock, return_value=[], ) - async def test_token_no_expiry_with_config(self, *args): + async def test_token_no_expiry_with_config(self, *args) -> None: token = get_jwt_token("aliceid", exp_in=-1, claims=default_claims) - result = await self.hs.mockmod.check_oauth( + result = await self.token_authenticator.check_oauth( "alice", "com.famedly.login.token.oauth", {"token": token} ) self.assertEqual(result[0], "@alice:example.test") @@ -154,9 +155,9 @@ async def test_token_no_expiry_with_config(self, *args): new_callable=mock.AsyncMock, return_value=[], ) - async def test_valid_login(self, *args): + async def test_valid_login(self, *args) -> None: token = get_jwt_token("aliceid", claims=default_claims) - result = await self.hs.mockmod.check_oauth( + result = await self.token_authenticator.check_oauth( "alice", "com.famedly.login.token.oauth", {"token": token} ) self.assertEqual(result[0], "@alice:example.test") @@ -169,18 +170,18 @@ async def test_valid_login(self, *args): "synapse.module_api.ModuleApi.record_user_external_id", new_callable=mock.AsyncMock, ) - async def test_valid_login_register(self, *args): + async def test_valid_login_register(self, *args) -> None: token = get_jwt_token("aliceid", claims=default_claims) - result = await self.hs.mockmod.check_oauth( + result = await self.token_authenticator.check_oauth( "alice", "com.famedly.login.token.oauth", {"token": token} ) self.assertEqual(result[0], "@alice:example.test") - async def test_invalid_scope(self): + async def test_invalid_scope(self) -> None: claims = default_claims.copy() claims["scope"] = "foo" token = get_jwt_token("aliceid", claims=claims) - result = await self.hs.mockmod.check_oauth( + result = await self.token_authenticator.check_oauth( "alice", "com.famedly.login.token.oauth", {"token": token} ) self.assertEqual(result, None) @@ -204,9 +205,9 @@ async def test_invalid_scope(self): new_callable=mock.AsyncMock, return_value=[], ) - async def test_fetch_jwks(self, *args): + async def test_fetch_jwks(self, *args) -> None: token = get_jwt_token("aliceid", claims=default_claims) - result = await self.hs.mockmod.check_oauth( + result = await self.token_authenticator.check_oauth( "alice", "com.famedly.login.token.oauth", {"token": token} ) self.assertEqual(result[0], "@alice:example.test") @@ -229,9 +230,9 @@ async def test_fetch_jwks(self, *args): new_callable=mock.AsyncMock, ) @mock.patch("synapse.module_api.ModuleApi.register_user") - async def test_login_register_admin(self, register_user_mock, *args): + async def test_login_register_admin(self, register_user_mock, *args) -> None: token = get_jwt_token("aliceid", claims=default_claims) - result = await self.hs.mockmod.check_oauth( + result = await self.token_authenticator.check_oauth( "alice", "com.famedly.login.token.oauth", {"token": token} ) @@ -256,9 +257,11 @@ async def test_login_register_admin(self, register_user_mock, *args): new_callable=mock.AsyncMock, ) @mock.patch("synapse.module_api.ModuleApi.register_user") - async def test_login_register_multiple_admin_paths(self, register_user_mock, *args): + async def test_login_register_multiple_admin_paths( + self, register_user_mock, *args + ) -> None: token = get_jwt_token("aliceid", claims=default_claims) - result = await self.hs.mockmod.check_oauth( + result = await self.token_authenticator.check_oauth( "alice", "com.famedly.login.token.oauth", {"token": token} ) @@ -280,9 +283,11 @@ async def test_login_register_multiple_admin_paths(self, register_user_mock, *ar new_callable=mock.AsyncMock, ) @mock.patch("synapse.module_api.ModuleApi.register_user") - async def test_login_register_admin_negative(self, register_user_mock, *args): + async def test_login_register_admin_negative( + self, register_user_mock, *args + ) -> None: token = get_jwt_token("aliceid", claims=default_claims) - result = await self.hs.mockmod.check_oauth( + result = await self.token_authenticator.check_oauth( "alice", "com.famedly.login.token.oauth", {"token": token} ) @@ -297,9 +302,11 @@ async def test_login_register_admin_negative(self, register_user_mock, *args): "synapse.module_api.ModuleApi.record_user_external_id", new_callable=mock.AsyncMock, ) - async def test_login_register_external_user_id(self, external_id_mock, *args): + async def test_login_register_external_user_id( + self, external_id_mock, *args + ) -> None: token = get_jwt_token("aliceid", claims=default_claims) - result = await self.hs.mockmod.check_oauth( + result = await self.token_authenticator.check_oauth( "alice", "com.famedly.login.token.oauth", {"token": token} ) @@ -328,9 +335,9 @@ async def test_login_register_external_user_id(self, external_id_mock, *args): "synapse_token_authenticator.TokenAuthenticator._add_user_email", new_callable=mock.AsyncMock, ) - async def test_login_register_threepid(self, add_threepid_mock, *args): + async def test_login_register_threepid(self, add_threepid_mock, *args) -> None: token = get_jwt_token("aliceid", claims=default_claims) - result = await self.hs.mockmod.check_oauth( + result = await self.token_authenticator.check_oauth( "alice", "com.famedly.login.token.oauth", {"token": token} ) @@ -349,9 +356,9 @@ async def test_login_register_threepid(self, add_threepid_mock, *args): ("http://test.example", "aliceid"), ], ) - async def test_login_check_external_id(self, *args): + async def test_login_check_external_id(self, *args) -> None: token = get_jwt_token("aliceid", claims=default_claims) - result = await self.hs.mockmod.check_oauth( + result = await self.token_authenticator.check_oauth( "alice", "com.famedly.login.token.oauth", {"token": token} ) self.assertEqual(result[0], "@alice:example.test") @@ -362,9 +369,9 @@ async def test_login_check_external_id(self, *args): new_callable=mock.AsyncMock, return_value=[("some_auth_provider", "some_external_id")], ) - async def test_login_check_external_id_negative(self, *args): + async def test_login_check_external_id_negative(self, *args) -> None: token = get_jwt_token("aliceid", claims=default_claims) - result = await self.hs.mockmod.check_oauth( + result = await self.token_authenticator.check_oauth( "alice", "com.famedly.login.token.oauth", {"token": token} ) self.assertEqual(result, None) @@ -378,14 +385,14 @@ async def test_login_check_external_id_negative(self, *args): new_callable=mock.AsyncMock, return_value=[("some_auth_provider", "some_external_id")], ) - async def test_login_check_external_id_disabled(self, *args): + async def test_login_check_external_id_disabled(self, *args) -> None: token = get_jwt_token("aliceid", claims=default_claims) - result = await self.hs.mockmod.check_oauth( + result = await self.token_authenticator.check_oauth( "alice", "com.famedly.login.token.oauth", {"token": token} ) self.assertEqual(result[0], "@alice:example.test") - config_for_introspection = { + config_for_introspection: JsonDict = { "modules": [ { "module": "synapse_token_authenticator.TokenAuthenticator", @@ -416,9 +423,9 @@ async def test_login_check_external_id_disabled(self, *args): "synapse.module_api.ModuleApi.record_user_external_id", new_callable=mock.AsyncMock, ) - async def test_valid_login_introspection(self, *args): + async def test_valid_login_introspection(self, *args) -> None: token = get_jwt_token("aliceid", claims=default_claims) - result = await self.hs.mockmod.check_oauth( + result = await self.token_authenticator.check_oauth( "alice", "com.famedly.login.token.oauth", {"token": token} ) self.assertEqual(result[0], "@alice:example.test") @@ -433,9 +440,9 @@ async def test_valid_login_introspection(self, *args): "synapse.http.client.SimpleHttpClient.request", side_effect=mock_for_oauth ) @mock.patch("synapse.module_api.ModuleApi.check_user_exists", return_value=False) - async def test_login_introspection_notify_fails(self, *args): + async def test_login_introspection_notify_fails(self, *args) -> None: token = get_jwt_token("aliceid", claims=default_claims) - result = await self.hs.mockmod.check_oauth( + result = await self.token_authenticator.check_oauth( "alice", "com.famedly.login.token.oauth", {"token": token} ) self.assertEqual(result, None) @@ -456,9 +463,9 @@ async def test_login_introspection_notify_fails(self, *args): "synapse.module_api.ModuleApi.record_user_external_id", new_callable=mock.AsyncMock, ) - async def test_login_introspection_notify_fails_but_ok(self, *args): + async def test_login_introspection_notify_fails_but_ok(self, *args) -> None: token = get_jwt_token("aliceid", claims=default_claims) - result = await self.hs.mockmod.check_oauth( + result = await self.token_authenticator.check_oauth( "alice", "com.famedly.login.token.oauth", {"token": token} ) self.assertEqual(result[0], "@alice:example.test") @@ -472,11 +479,11 @@ async def test_login_introspection_notify_fails_but_ok(self, *args): @mock.patch( "synapse.http.client.SimpleHttpClient.request", side_effect=mock_for_oauth ) - async def test_login_introspection_invalid_scope(self, *args): + async def test_login_introspection_invalid_scope(self, *args) -> None: claims = default_claims.copy() claims["scope"] = "foo" token = get_jwt_token("aliceid", claims=claims) - result = await self.hs.mockmod.check_oauth( + result = await self.token_authenticator.check_oauth( "alice", "com.famedly.login.token.oauth", {"token": token} ) self.assertEqual(result, None) @@ -496,9 +503,11 @@ async def test_login_introspection_invalid_scope(self, *args): new_callable=mock.AsyncMock, ) @mock.patch("synapse.module_api.ModuleApi.register_user") - async def test_login_introspection_register_admin(self, register_user_mock, *args): + async def test_login_introspection_register_admin( + self, register_user_mock, *args + ) -> None: token = get_jwt_token("aliceid", claims=default_claims) - result = await self.hs.mockmod.check_oauth( + result = await self.token_authenticator.check_oauth( "alice", "com.famedly.login.token.oauth", {"token": token} ) register_user_mock.assert_called_with("alice", admin=True) @@ -521,9 +530,9 @@ async def test_login_introspection_register_admin(self, register_user_mock, *arg @mock.patch("synapse.module_api.ModuleApi.register_user") async def test_login_introspection_register_multiple_admin_paths( self, register_user_mock, *args - ): + ) -> None: token = get_jwt_token("aliceid", claims=default_claims) - result = await self.hs.mockmod.check_oauth( + result = await self.token_authenticator.check_oauth( "alice", "com.famedly.login.token.oauth", {"token": token} ) register_user_mock.assert_called_with("alice", admin=True) @@ -537,9 +546,11 @@ async def test_login_introspection_register_multiple_admin_paths( "synapse.module_api.ModuleApi.record_user_external_id", new_callable=mock.AsyncMock, ) - async def test_login_introspection_external_user_id(self, external_id_mock, *args): + async def test_login_introspection_external_user_id( + self, external_id_mock, *args + ) -> None: token = get_jwt_token("aliceid", claims=default_claims) - result = await self.hs.mockmod.check_oauth( + result = await self.token_authenticator.check_oauth( "alice", "com.famedly.login.token.oauth", {"token": token} ) external_id_mock.assert_called_with( @@ -567,9 +578,9 @@ async def test_login_introspection_external_user_id(self, external_id_mock, *arg "synapse_token_authenticator.TokenAuthenticator._add_user_email", new_callable=mock.AsyncMock, ) - async def test_login_introspection_threepid(self, add_threepid_mock, *args): + async def test_login_introspection_threepid(self, add_threepid_mock, *args) -> None: token = get_jwt_token("aliceid", claims=default_claims) - result = await self.hs.mockmod.check_oauth( + result = await self.token_authenticator.check_oauth( "alice", "com.famedly.login.token.oauth", {"token": token} ) add_threepid_mock.assert_called_with( diff --git a/tests/test_oidc.py b/tests/test_oidc.py index efc8a7a..376ce9a 100644 --- a/tests/test_oidc.py +++ b/tests/test_oidc.py @@ -21,14 +21,14 @@ class OIDCTests(ModuleApiTestCase): - async def test_wrong_login_type(self): - result = await self.hs.mockmod.check_oidc_auth( + async def test_wrong_login_type(self) -> None: + result = await self.token_authenticator.check_oidc_auth( "alice", "m.password", get_oidc_login("alice") ) self.assertEqual(result, None) - async def test_missing_token(self): - result = await self.hs.mockmod.check_oidc_auth( + async def test_missing_token(self) -> None: + result = await self.token_authenticator.check_oidc_auth( "alice", "com.famedly.login.token,oidc", {} ) self.assertEqual(result, None) @@ -36,8 +36,8 @@ async def test_missing_token(self): @mock.patch( "synapse.http.client.SimpleHttpClient.request", side_effect=mock_idp_req ) - async def test_invalid_token(self, *args): - result = await self.hs.mockmod.check_oidc_auth( + async def test_invalid_token(self, *args) -> None: + result = await self.token_authenticator.check_oidc_auth( "alice", "com.famedly.login.token.oidc", {"token": "invalid"} ) self.assertEqual(result, None) @@ -45,8 +45,8 @@ async def test_invalid_token(self, *args): @mock.patch( "synapse.http.client.SimpleHttpClient.request", side_effect=mock_idp_req ) - async def test_valid_login(self, *args): - result = await self.hs.mockmod.check_oidc_auth( + async def test_valid_login(self, *args) -> None: + result = await self.token_authenticator.check_oidc_auth( "alice", "com.famedly.login.token.oidc", get_oidc_login("alice") ) self.assertEqual(result[0], "@alice:example.test") @@ -73,8 +73,8 @@ async def test_valid_login(self, *args): ] } ) - async def test_valid_login_unicode_client_id(self, *args): - result = await self.hs.mockmod.check_oidc_auth( + async def test_valid_login_unicode_client_id(self, *args) -> None: + result = await self.token_authenticator.check_oidc_auth( "alice", "com.famedly.login.token.oidc", get_oidc_login("alice") ) self.assertEqual(result[0], "@alice:example.test") @@ -83,8 +83,8 @@ async def test_valid_login_unicode_client_id(self, *args): "synapse.http.client.SimpleHttpClient.request", side_effect=mock_idp_req ) @mock.patch("synapse.module_api.ModuleApi.check_user_exists", return_value=False) - async def test_valid_login_no_register(self, *args): - result = await self.hs.mockmod.check_oidc_auth( + async def test_valid_login_no_register(self, *args) -> None: + result = await self.token_authenticator.check_oidc_auth( "alice", "com.famedly.login.token.oidc", get_oidc_login("alice") ) self.assertEqual(result, None) @@ -112,8 +112,8 @@ async def test_valid_login_no_register(self, *args): ] } ) - async def test_valid_login_with_register(self, *args): - result = await self.hs.mockmod.check_oidc_auth( + async def test_valid_login_with_register(self, *args) -> None: + result = await self.token_authenticator.check_oidc_auth( "alice", "com.famedly.login.token.oidc", get_oidc_login("alice") ) self.assertEqual(result[0], "@alice:example.test") diff --git a/tests/test_sta_utils.py b/tests/test_sta_utils.py index b8c2dad..90d290c 100644 --- a/tests/test_sta_utils.py +++ b/tests/test_sta_utils.py @@ -6,7 +6,7 @@ ) -def test_get_path_in_dict(): +def test_get_path_in_dict() -> None: assert get_path_in_dict("foo", {"foo": 3}) == 3 assert get_path_in_dict("foo", {"loo": 3}) is None assert get_path_in_dict("foo", [3, 4]) is None @@ -38,7 +38,7 @@ def test_get_path_in_dict(): assert get_path_in_dict([["foo", "loo"], []], {"foo": {"loo": 3}}) == 3 -def test_get_path_in_dict_pathlist_fallback_on_missing_key(): +def test_get_path_in_dict_pathlist_fallback_on_missing_key() -> None: """When the first path's key is entirely absent, later paths must still be tried.""" assert ( get_path_in_dict([["missing", "sub"], ["foo", "bar"]], {"foo": {"bar": 3}}) == 3 @@ -54,7 +54,7 @@ def test_get_path_in_dict_pathlist_fallback_on_missing_key(): ) -def test_get_path_in_dict_pathlist_non_dict_intermediate(): +def test_get_path_in_dict_pathlist_non_dict_intermediate() -> None: """When an intermediate value is a non-dict (e.g. int), later paths must still be tried.""" assert ( get_path_in_dict( @@ -71,7 +71,7 @@ def test_get_path_in_dict_pathlist_non_dict_intermediate(): ) -def test_get_path_in_dict_zitadel_admin_path(): +def test_get_path_in_dict_zitadel_admin_path() -> None: """Real-world scenario: Zitadel project-scoped role claims with PathList fallback.""" token = { "urn:zitadel:iam:org:project:12345:roles": { @@ -94,7 +94,7 @@ def test_get_path_in_dict_zitadel_admin_path(): ) == {"org_id": "famedly.localhost"} -def test_validate_scopes(): +def test_validate_scopes() -> None: assert validate_scopes("foo boo", "boo foo") assert validate_scopes(["foo", "boo"], "boo foo") assert not validate_scopes("foo boo", "foo") @@ -102,12 +102,12 @@ def test_validate_scopes(): assert validate_scopes("foo boo", "boo foo loo") -def test_if_not_none(): +def test_if_not_none() -> None: assert if_not_none(lambda x: x + 1)(3) == 4 assert if_not_none(lambda x: x + 1)(None) is None -def test_all_list_elems_are_equal_return_the_elem(): +def test_all_list_elems_are_equal_return_the_elem() -> None: assert all_list_elems_are_equal_return_the_elem([None, None]) is None assert all_list_elems_are_equal_return_the_elem([]) is None assert all_list_elems_are_equal_return_the_elem([3, None]) == 3 diff --git a/tests/test_validators.py b/tests/test_validators.py index fc3533d..167b7d7 100644 --- a/tests/test_validators.py +++ b/tests/test_validators.py @@ -1,33 +1,34 @@ from pytest import fixture +from synapse.types import JsonDict from synapse_token_authenticator.claims_validator import parse_validator -def test_validator_exists(): +def test_validator_exists() -> None: assert parse_validator(["exist"]).validate(None) -def test_validator_in(): +def test_validator_in() -> None: assert parse_validator(["in", "foo"]).validate({"foo": 3}) assert not parse_validator(["in", "foo"]).validate({"loo": 3}) assert parse_validator(["in", "foo", ["equal", 3]]).validate({"foo": 3}) assert not parse_validator(["in", "foo", ["equal", 3]]).validate({"foo": 4}) -def test_validator_not(): +def test_validator_not() -> None: assert not parse_validator(["not", ["in", "foo"]]).validate({"foo": 3}) assert parse_validator(["not", ["in", "foo"]]).validate({"loo": 3}) assert not parse_validator(["not", ["exist"]]).validate(None) -def test_validator_equal(): +def test_validator_equal() -> None: assert parse_validator(["equal", 3]).validate(3) assert not parse_validator(["equal", 3]).validate(4) assert parse_validator(["equal", {"hi": 3}]).validate({"hi": 3}) assert not parse_validator(["equal", {"hi": 3}]).validate({"hi": 4}) -def test_validator_regex(): +def test_validator_regex() -> None: txt = "The rain in Spain" regexp = "The.*Spain" assert parse_validator(["regex", regexp]).validate(txt) @@ -35,7 +36,7 @@ def test_validator_regex(): assert not parse_validator(["regex", regexp]).validate("bad string") -def test_validator_all_of(): +def test_validator_all_of() -> None: assert parse_validator(["all_of", [["in", "foo"], ["in", "loo"]]]).validate( {"foo": 3, "loo": 4} ) @@ -45,7 +46,7 @@ def test_validator_all_of(): assert parse_validator(["all_of", []]).validate([]) -def test_validator_any_of(): +def test_validator_any_of() -> None: assert parse_validator(["any_of", [["in", "foo"], ["in", "loo"]]]).validate( {"foo": 3, "loo": 4} ) @@ -58,7 +59,7 @@ def test_validator_any_of(): assert not parse_validator(["any_of", []]).validate({}) -def test_validator_list_all_of(): +def test_validator_list_all_of() -> None: assert parse_validator(["list_all_of", ["in", "foo"]]).validate( [{"foo": 3}, {"foo": 4}] ) @@ -68,7 +69,7 @@ def test_validator_list_all_of(): ) -def test_validator_list_any_of(): +def test_validator_list_any_of() -> None: assert parse_validator(["list_any_of", ["in", "foo"]]).validate( [{"foo": 3}, {"foo": 4}] ) @@ -79,7 +80,7 @@ def test_validator_list_any_of(): @fixture -def jwt_claims(): +def jwt_claims() -> JsonDict: return { "foo": "hello", "bar": "hi", @@ -92,7 +93,7 @@ def jwt_claims(): } -def test_validator_full(jwt_claims): +def test_validator_full(jwt_claims) -> None: required_claims = { "type": "all_of", "validators": [ @@ -124,7 +125,7 @@ def test_validator_full(jwt_claims): assert parse_validator(required_claims).validate(jwt_claims) -def test_validator_short(jwt_claims): +def test_validator_short(jwt_claims) -> None: required_claims_short = [ "all_of", [