From a9718fda627a3375c284f7d032b3a5b4ccc7b72d Mon Sep 17 00:00:00 2001 From: Peter Chubb Date: Mon, 22 Jun 2020 15:26:31 +1000 Subject: [PATCH] Add Digest Authentication to client Uses the python3-digest package to decode the HTTP digest authorisation challenge and create an appropriate extra header to authenticate. --- src/websockets/__init__.py | 1 + src/websockets/client.py | 28 +++++++++++++++++++++++++++- src/websockets/exceptions.py | 16 ++++++++++++++++ 3 files changed, 44 insertions(+), 1 deletion(-) diff --git a/src/websockets/__init__.py b/src/websockets/__init__.py index ea1d829a3..6b61de9f1 100644 --- a/src/websockets/__init__.py +++ b/src/websockets/__init__.py @@ -12,6 +12,7 @@ __all__ = [ "AbortHandshake", + "AuthenticationRequest", "basic_auth_protocol_factory", "BasicAuthWebSocketServerProtocol", "connect", diff --git a/src/websockets/client.py b/src/websockets/client.py index be055310d..57918ee6c 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -8,10 +8,12 @@ import functools import logging import warnings +from python_digest import parse_digest_challenge, build_authorization_request from types import TracebackType from typing import Any, Generator, List, Optional, Sequence, Tuple, Type, cast from .exceptions import ( + AuthenticationRequest, InvalidHandshake, InvalidHeader, InvalidMessage, @@ -254,7 +256,7 @@ async def handshake( else: request_headers["Host"] = f"{wsuri.host}:{wsuri.port}" - if wsuri.user_info: + if wsuri.user_info and "Authorization" not in request_headers: request_headers["Authorization"] = build_authorization_basic( *wsuri.user_info ) @@ -294,6 +296,10 @@ async def handshake( if "Location" not in response_headers: raise InvalidHeader("Location") raise RedirectHandshake(response_headers["Location"]) + elif status_code == 401: + if "WWW-Authenticate" not in response_headers: + raise InvalidHeader("WWW-Authenticate") + raise AuthenticationRequest(response_headers["WWW-Authenticate"]) elif status_code != 101: raise InvalidStatusCode(status_code) @@ -479,6 +485,18 @@ def __init__( self._create_connection = create_connection self._wsuri = wsuri + def handle_digest_auth(self, response: str) -> None: + wsuri = self._wsuri + challenge = parse_digest_challenge(response) + if challenge is None: + raise AuthenticationRequest(response) + kd = build_authorization_request(wsuri.user_info[0], + 'GET', wsuri.resource_name, 1, + challenge, + password=wsuri.user_info[1]) + return kd + + def handle_redirect(self, uri: str) -> None: # Update the state of this instance to connect to a new URI. old_wsuri = self._wsuri @@ -533,12 +551,18 @@ def __await__(self) -> Generator[Any, None, WebSocketClientProtocol]: return self.__await_impl__().__await__() async def __await_impl__(self) -> WebSocketClientProtocol: + auth_header = None for redirects in range(self.MAX_REDIRECTS_ALLOWED): transport, protocol = await self._create_connection() # https://github.com/python/typeshed/pull/2756 transport = cast(asyncio.Transport, transport) protocol = cast(WebSocketClientProtocol, protocol) + if auth_header is not None: + if protocol.extra_headers is None: + protocol.extra_headers = {} + protocol.extra_headers['Authorization'] = auth_header + try: try: await protocol.handshake( @@ -557,6 +581,8 @@ async def __await_impl__(self) -> WebSocketClientProtocol: return protocol except RedirectHandshake as exc: self.handle_redirect(exc.uri) + except AuthenticationRequest as exc: + auth_header = self.handle_digest_auth(exc.authresponse) else: raise SecurityError("too many redirects") diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index 9873a1717..d22eeea5a 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -20,6 +20,7 @@ * :exc:`InvalidParameterValue` * :exc:`AbortHandshake` * :exc:`RedirectHandshake` + * :exc:`AuthenticationRequest` * :exc:`InvalidState` * :exc:`InvalidURI` * :exc:`PayloadTooBig` @@ -53,6 +54,7 @@ "InvalidParameterValue", "AbortHandshake", "RedirectHandshake", + "AuthenticationRequest", "InvalidState", "InvalidURI", "PayloadTooBig", @@ -326,6 +328,20 @@ def __str__(self) -> str: return f"redirect to {self.uri}" +class AuthenticationRequest(InvalidHandshake): + """ + Raised when a 401 Unauthorized response is received. + Another implementation detail, to allow e.g., Digest authentication + + """ + def __init__(self, details : str) -> None: + self.authresponse = details + + def __str__(self) -> str: + return f"WWW-Authenticate: {self.authresponse}" + + + class InvalidState(WebSocketException, AssertionError): """ Raised when an operation is forbidden in the current state.