diff --git a/.gitignore b/.gitignore index 555d276e69..b22c7e5f1c 100644 --- a/.gitignore +++ b/.gitignore @@ -28,6 +28,9 @@ virtualenv-components virtualenv-components-osx .venv-st2devbox +# st2client build files +st2client/build/* + # generated travis conf conf/st2.travis.conf # generated GitHub Actions conf diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 020ca834e2..79ac438d9d 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -6,6 +6,15 @@ in development Added ~~~~~ + +* Revised support for SSO backends + SAML2 included by default #5664 + + SAML backend on a separate repository, included as a dependecy , https://github.com/StackStorm/st2-auth-backend-sso-saml2 + + RBAC support also baked into SSO backends + + Contributed by @pimguilherme + * Move `git clone` to `user_home/.st2packs` #5845 * Error on `st2ctl status` when running in Kubernetes. #5851 diff --git a/requirements.txt b/requirements.txt index 3d5395bb0f..87356ef25d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -19,6 +19,7 @@ decorator==4.4.2 dnspython>=1.16.0,<2.0.0 eventlet==0.30.2 flex==6.14.1 +git+https://github.com/pimguilherme/st2-auth-backend-sso-saml2.git@feat/saml#egg=st2-auth-backend-sso-saml2 gitdb==4.0.2 gitpython==3.1.15 greenlet==1.0.0 diff --git a/st2auth/in-requirements.txt b/st2auth/in-requirements.txt index 0d9e5e01a3..87d6be8cb2 100644 --- a/st2auth/in-requirements.txt +++ b/st2auth/in-requirements.txt @@ -6,6 +6,8 @@ passlib pymongo six stevedore +# for SAML sso +git+https://github.com/pimguilherme/st2-auth-backend-sso-saml2.git@feat/saml#egg=st2-auth-backend-sso-saml2 # For backward compatibility reasons, flat file backend is installed by default st2-auth-backend-flat-file@ git+https://github.com/StackStorm/st2-auth-backend-flat-file.git@master st2-auth-ldap@ git+https://github.com/StackStorm/st2-auth-ldap.git@master diff --git a/st2auth/requirements.txt b/st2auth/requirements.txt index 1d6a06de81..13cf8ad85d 100644 --- a/st2auth/requirements.txt +++ b/st2auth/requirements.txt @@ -7,6 +7,7 @@ # update the component requirements.txt bcrypt==3.2.0 eventlet==0.30.2 +git+https://github.com/pimguilherme/st2-auth-backend-sso-saml2.git@feat/saml#egg=st2-auth-backend-sso-saml2 gunicorn==20.1.0 oslo.config>=1.12.1,<1.13 passlib==1.7.4 diff --git a/st2auth/st2auth/controllers/v1/sso.py b/st2auth/st2auth/controllers/v1/sso.py index ef1096462c..959d6081fa 100644 --- a/st2auth/st2auth/controllers/v1/sso.py +++ b/st2auth/st2auth/controllers/v1/sso.py @@ -14,18 +14,30 @@ import datetime import json +from uuid import uuid4 from oslo_config import cfg from six.moves import http_client from six.moves import urllib +from st2common.router import GenericRequestParam import st2auth.handlers as handlers from st2auth import sso as st2auth_sso +from st2auth.sso.base import BaseSingleSignOnBackendResponse from st2common.exceptions import auth as auth_exc from st2common import log as logging from st2common import router - +from st2common.models.db.auth import SSORequestDB +from st2common.services.access import ( + create_cli_sso_request, + create_web_sso_request, + get_sso_request_by_request_id, +) +from st2common.exceptions.auth import SSORequestNotFoundError +from st2common.util.crypto import read_crypto_key_from_dict, symmetric_encrypt +from st2common.util.date import get_datetime_utc_now +from st2common.util.jsonify import json_decode LOG = logging.getLogger(__name__) SSO_BACKEND = st2auth_sso.get_sso_backend() @@ -35,25 +47,96 @@ class IdentityProviderCallbackController(object): def __init__(self): self.st2_auth_handler = handlers.ProxyAuthHandler() + # Validates the incoming SSO response by getting its ID, checking against + # the database for outstanding SSO requests and checking to see if they have already expired + def _validate_and_delete_sso_request(self, response): + + # Grabs the ID from the SSO response based on the backend + request_id = SSO_BACKEND.get_request_id_from_response(response) + if request_id is None: + raise ValueError("Invalid request id coming from SAML response") + + LOG.debug("Validating SSO request %s from received response!", request_id) + + # Grabs the original SSO request based on the ID + original_sso_request = None + try: + original_sso_request = get_sso_request_by_request_id(request_id) + except SSORequestNotFoundError: + pass + + if original_sso_request is None: + raise ValueError( + "This SSO request is invalid (it may have already been used)" + ) + + # Verifies if the request has expired already + LOG.info( + "Incoming SSO response matching request: %s, with expiry: %s", + original_sso_request.request_id, + original_sso_request.expiry, + ) + if original_sso_request.expiry <= get_datetime_utc_now(): + raise ValueError( + "The SSO request associated with this response has already expired!" + ) + + # All done, we should not need to use this again :) + LOG.debug( + "Deleting original SSO request from database with ID %s", + original_sso_request.id, + ) + original_sso_request.delete() + + return original_sso_request + def post(self, response, **kwargs): try: + + original_sso_request = self._validate_and_delete_sso_request(response) + + # Obtain user details from the SSO response from the backend verified_user = SSO_BACKEND.verify_response(response) + if not isinstance(verified_user, BaseSingleSignOnBackendResponse): + return process_failure_response( + http_client.INTERNAL_SERVER_ERROR, + "Unexpected SSO backend response type. Expected " + "BaseSingleSignOnBackendResponse instance!", + ) + + LOG.info( + "Authenticating SSO user [%s] with groups [%s]", + verified_user.username, + verified_user.groups, + ) - st2_auth_token_create_request = { - "user": verified_user["username"], - "ttl": None, - } + st2_auth_token_create_request = GenericRequestParam( + ttl=None, + groups=verified_user.groups, + ) st2_auth_token = self.st2_auth_handler.handle_auth( request=st2_auth_token_create_request, - remote_addr=verified_user["referer"], - remote_user=verified_user["username"], + remote_addr=verified_user.referer, + remote_user=verified_user.username, headers={}, ) - return process_successful_authn_response( - verified_user["referer"], st2_auth_token - ) + # Depending on the type of SSO request we should handle the response differently + # ie WEB gets redirected and CLI gets an encrypted callback + if original_sso_request.type == SSORequestDB.Type.WEB: + return process_successful_sso_web_response( + verified_user.referer, st2_auth_token + ) + elif original_sso_request.type == SSORequestDB.Type.CLI: + return process_successful_sso_cli_response( + verified_user.referer, original_sso_request.key, st2_auth_token + ) + else: + raise NotImplementedError( + "Unexpected SSO request type [%s] -- I can deal with web and cli" + % original_sso_request.type + ) except NotImplementedError as e: return process_failure_response(http_client.INTERNAL_SERVER_ERROR, e) except auth_exc.SSOVerificationError as e: @@ -63,14 +146,76 @@ def post(self, response, **kwargs): class SingleSignOnRequestController(object): - def get(self, referer): + def _create_sso_request(self, handler, **kwargs): + + request_id = "id_%s" % str(uuid4()) + sso_request = handler(request_id=request_id, **kwargs) + LOG.debug( + "Created SSO request with request id %s and expiry %s and type %s", + request_id, + sso_request.expiry, + sso_request.type, + ) + return sso_request + + # web-intended SSO + def get_web(self, referer): try: + sso_request = self._create_sso_request(create_web_sso_request) + response = router.Response(status=http_client.TEMPORARY_REDIRECT) - response.location = SSO_BACKEND.get_request_redirect_url(referer) + response.location = SSO_BACKEND.get_request_redirect_url( + sso_request.request_id, referer + ) return response except NotImplementedError as e: + if sso_request: + sso_request.delete() return process_failure_response(http_client.INTERNAL_SERVER_ERROR, e) except Exception as e: + if sso_request: + sso_request.delete() + raise e + + # cli-intended SSO + def post_cli(self, response): + sso_request = None + try: + key = getattr(response, "key", None) + callback_url = getattr(response, "callback_url", None) + # This is already checked at the API level, but aanyway.. + if not key or not callback_url: + raise ValueError("Missing either key and/or callback_url!") + + try: + read_crypto_key_from_dict(json_decode(key)) + except Exception: + LOG.warn("Could not decode incoming SSO CLI request key") + raise ValueError( + "The provided key is invalid! It should be stackstorm-compatible AES key" + ) + + sso_request = self._create_sso_request(create_cli_sso_request, key=key) + response = router.Response(status=http_client.OK) + response.content_type = "application/json" + response.json = { + "sso_url": SSO_BACKEND.get_request_redirect_url( + sso_request.request_id, callback_url + ), + # this is needed because the db doesnt save microseconds + # pylint: disable=E1101 + "expiry": sso_request.expiry.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + + "000+00:00", + } + + return response + except NotImplementedError as e: + if sso_request: + sso_request.delete() + return process_failure_response(http_client.INTERNAL_SERVER_ERROR, e) + except Exception as e: + if sso_request: + sso_request.delete() raise e @@ -94,22 +239,52 @@ def get(self): CALLBACK_SUCCESS_RESPONSE_BODY = """ """ -def process_successful_authn_response(referer, token): - token_json = { +def token_to_json(token): + return { "id": str(token.id), "user": token.user, "token": token.token, @@ -118,6 +293,29 @@ def process_successful_authn_response(referer, token): "metadata": {}, } + +def process_successful_sso_cli_response(callback_url, key, token): + token_json = token_to_json(token) + + aes_key = read_crypto_key_from_dict(json_decode(key)) + encrypted_token = symmetric_encrypt(aes_key, json.dumps(token_json)) + + LOG.debug( + "Redirecting successfuly SSO CLI login to url [%s] " + "with extra parameters for the encrypted token", + callback_url, + ) + + # Response back to the browser has all the data in the query string, in an encrypted formta :) + resp = router.Response(status=http_client.FOUND) + resp.location = "%s?response=%s" % (callback_url, encrypted_token.decode("utf-8")) + + return resp + + +def process_successful_sso_web_response(referer, token): + token_json = token_to_json(token) + body = CALLBACK_SUCCESS_RESPONSE_BODY % referer resp = router.Response(body=body) resp.headers["Content-Type"] = "text/html" diff --git a/st2auth/st2auth/handlers.py b/st2auth/st2auth/handlers.py index f6540bcda7..81b55c94be 100644 --- a/st2auth/st2auth/handlers.py +++ b/st2auth/st2auth/handlers.py @@ -53,6 +53,39 @@ def handle_auth( ): raise NotImplementedError() + def sync_user_groups(self, extra, username, groups): + + if groups is None: + LOG.debug("No groups to sync for user '%s'", username) + return + + extra["username"] = username + extra["user_groups"] = groups + + LOG.debug( + 'Found "%s" groups for user "%s"' % (len(groups), username), + extra=extra, + ) + + user_db = UserDB(name=username) + + rbac_backend = get_rbac_backend() + syncer = rbac_backend.get_remote_group_to_role_syncer() + + try: + syncer.sync(user_db=user_db, groups=groups) + except Exception: + # Note: Failed sync is not fatal + LOG.exception( + 'Failed to synchronize remote groups for user "%s"' % (username), + extra=extra, + ) + else: + LOG.debug( + 'Successfully synchronized groups for user "%s"' % (username), + extra=extra, + ) + def _create_token_for_user(self, username, ttl=None): tokendb = create_token(username=username, ttl=ttl) return TokenAPI.from_model(tokendb) @@ -129,12 +162,21 @@ def handle_auth( ): remote_addr = headers.get("x-forwarded-for", remote_addr) extra = {"remote_addr": remote_addr} + LOG.debug( + "Authenticating for proxy with request [%s]", + getattr(request, "__dict__", None) if request else None, + ) if remote_user: ttl = getattr(request, "ttl", None) username = self._get_username_for_request(remote_user, request) try: token = self._create_token_for_user(username=username, ttl=ttl) + groups = getattr(request, "groups", None) + + if cfg.CONF.rbac.backend != "noop": + self.sync_user_groups(extra, username, groups) + except TTLTooLargeException as e: abort_request( status_code=http_client.BAD_REQUEST, message=six.text_type(e) @@ -228,33 +270,7 @@ def handle_auth( # No groups, return early return token - extra["username"] = username - extra["user_groups"] = user_groups - - LOG.debug( - 'Found "%s" groups for user "%s"' % (len(user_groups), username), - extra=extra, - ) - - user_db = UserDB(name=username) - - rbac_backend = get_rbac_backend() - syncer = rbac_backend.get_remote_group_to_role_syncer() - - try: - syncer.sync(user_db=user_db, groups=user_groups) - except Exception: - # Note: Failed sync is not fatal - LOG.exception( - 'Failed to synchronize remote groups for user "%s"' - % (username), - extra=extra, - ) - else: - LOG.debug( - 'Successfully synchronized groups for user "%s"' % (username), - extra=extra, - ) + self.sync_user_groups(extra, username, user_groups) return token return token diff --git a/st2auth/st2auth/sso/__init__.py b/st2auth/st2auth/sso/__init__.py index b6d0df930a..62403e2106 100644 --- a/st2auth/st2auth/sso/__init__.py +++ b/st2auth/st2auth/sso/__init__.py @@ -19,6 +19,7 @@ import traceback from oslo_config import cfg +from st2auth.sso.base import BaseSingleSignOnBackend from st2common import log as logging @@ -36,7 +37,7 @@ def get_available_backends(): return driver_loader.get_available_backends(namespace=BACKENDS_NAMESPACE) -def get_backend_instance(name): +def get_backend_instance(name) -> BaseSingleSignOnBackend: sso_backend_cls = driver_loader.get_backend_driver( namespace=BACKENDS_NAMESPACE, name=name ) @@ -69,7 +70,7 @@ def get_backend_instance(name): return sso_backend -def get_sso_backend(): +def get_sso_backend() -> BaseSingleSignOnBackend: """ Return SingleSignOnBackend class instance. """ diff --git a/st2auth/st2auth/sso/base.py b/st2auth/st2auth/sso/base.py index 5e11199818..970a798128 100644 --- a/st2auth/st2auth/sso/base.py +++ b/st2auth/st2auth/sso/base.py @@ -14,9 +14,38 @@ import abc import six +from typing import List -__all__ = ["BaseSingleSignOnBackend"] +__all__ = ["BaseSingleSignOnBackend", "BaseSingleSignOnBackendResponse"] + + +# This defines the expected response to be communicated back from verify_response methods +@six.add_metaclass(abc.ABCMeta) +class BaseSingleSignOnBackendResponse(object): + username: str = None + referer: str = None + groups: List[str] = None + + def __init__(self, username=None, referer=None, groups=[]): + self.username = username + self.groups = groups + self.referer = referer + + def __eq__(self, other): + if other is None: + return False + return ( + self.username == other.username + and self.groups == other.groups + and self.referer == other.referer + ) + + def __repr__(self): + return ( + f"BaseSingleSignOnBackendResponse(username={self.username}, groups={self.groups}" + + f", referer={self.referer}" + ) @six.add_metaclass(abc.ABCMeta) @@ -25,11 +54,18 @@ class BaseSingleSignOnBackend(object): Base single sign on authentication class. """ - def get_request_redirect_url(self, referer): + def get_request_redirect_url(self, referer) -> str: msg = 'The function "get_request_redirect_url" is not implemented in the base SSO backend.' raise NotImplementedError(msg) - def verify_response(self, response): + def get_request_id_from_response(self, response) -> str: + msg = ( + 'The function "get_request_id_from_response" is not implemented' + "in the base SSO backend." + ) + raise NotImplementedError(msg) + + def verify_response(self, response) -> BaseSingleSignOnBackendResponse: msg = ( 'The function "verify_response" is not implemented in the base SSO backend.' ) diff --git a/st2auth/st2auth/sso/noop.py b/st2auth/st2auth/sso/noop.py index 6699e084f3..55e764a809 100644 --- a/st2auth/st2auth/sso/noop.py +++ b/st2auth/st2auth/sso/noop.py @@ -21,7 +21,7 @@ NOT_IMPLEMENTED_MESSAGE = ( 'The default "noop" SSO backend is not a proper implementation. ' - "Please refer to the enterprise version for configuring SSO." + "Please configure SSO accordingly by selecting a proper backend." ) @@ -30,8 +30,11 @@ class NoOpSingleSignOnBackend(BaseSingleSignOnBackend): NoOp SSO authentication backend. """ - def get_request_redirect_url(self, referer): + def get_request_redirect_url(self, request_id, referer): raise NotImplementedError(NOT_IMPLEMENTED_MESSAGE) def verify_response(self, response): raise NotImplementedError(NOT_IMPLEMENTED_MESSAGE) + + def get_request_id_from_response(self, response) -> str: + raise NotImplementedError(NOT_IMPLEMENTED_MESSAGE) diff --git a/st2auth/tests/unit/controllers/v1/test_sso.py b/st2auth/tests/unit/controllers/v1/test_sso.py index 5596b0fb01..55e9e8b4eb 100644 --- a/st2auth/tests/unit/controllers/v1/test_sso.py +++ b/st2auth/tests/unit/controllers/v1/test_sso.py @@ -12,41 +12,88 @@ # See the License for the specific language governing permissions and # limitations under the License. -import st2tests.config as tests_config +# NOTE: We need to perform monkeypatch before importing ssl module otherwise tests will fail. +# See https://github.com/StackStorm/st2/pull/4834 for details +from st2common.util.monkey_patch import monkey_patch -tests_config.parse_args() +monkey_patch() -import json +from tests.base import FunctionalTest +from st2common.exceptions import auth as auth_exc +from st2auth.sso import noop +from st2auth.controllers.v1 import sso as sso_api_controller +from six.moves import urllib +from six.moves import http_client +from oslo_config import cfg import mock +import json +from typing import List +from st2auth.sso.base import BaseSingleSignOnBackendResponse +from st2common.models.db.auth import SSORequestDB +from st2common.persistence.auth import SSORequest, Token +from st2common.persistence.rbac import GroupToRoleMapping, UserRoleAssignment, Role +from st2common.models.db.rbac import GroupToRoleMappingDB, RoleDB +from st2common.services.access import ( + DEFAULT_SSO_REQUEST_TTL, + create_web_sso_request, + create_cli_sso_request, +) +import st2tests.config as tests_config +from st2common.util.crypto import read_crypto_key_from_dict, symmetric_decrypt -from oslo_config import cfg -from six.moves import http_client -from six.moves import urllib +from st2common.util import date as date_utils -from st2auth.controllers.v1 import sso as sso_api_controller -from st2auth.sso import noop -from st2common.exceptions import auth as auth_exc -from tests.base import FunctionalTest +tests_config.parse_args() SSO_V1_PATH = "/v1/sso" -SSO_REQUEST_V1_PATH = SSO_V1_PATH + "/request" +SSO_REQUEST_WEB_V1_PATH = SSO_V1_PATH + "/request/web" +SSO_REQUEST_CLI_V1_PATH = SSO_V1_PATH + "/request/cli" SSO_CALLBACK_V1_PATH = SSO_V1_PATH + "/callback" MOCK_REFERER = "https://127.0.0.1" MOCK_USER = "stanley" +MOCK_CALLBACK_URL = "http://localhost:34999" +MOCK_CLI_REQUEST_KEY = read_crypto_key_from_dict( + { + "hmacKey": { + "hmacKeyString": "-qdRklvhm4xvzIfaL6Z2nmQ-2N-c4IUtNa1_BowCVfg", + "size": 256, + }, + "aesKeyString": "0UyXFjBTQ9PMyHZ0mqrvuqCSzesuFup1d6m-4Vi3vdo", + "mode": "CBC", + "size": 256, + } +) +MOCK_CLI_REQUEST_KEY_ALTERNATIVE = read_crypto_key_from_dict( + { + "hmacKey": { + "hmacKeyString": "ENb-2COFGmdnshSnjjz3wePrxypVzCf9Jq2iuhXEgbc", + "size": 256, + }, + "aesKeyString": "8TpT_RaA6dlharswjqVlJSw027B60UkgnQqcgGfmf08", + "mode": "CBC", + "size": 256, + } +) +MOCK_CLI_REQUEST_KEY_JSON = MOCK_CLI_REQUEST_KEY.to_json() +MOCK_REQUEST_ID = "test-id" +MOCK_GROUPS = ["test", "test2"] +MOCK_VERIFIED_USER_OBJECT = BaseSingleSignOnBackendResponse( + referer=MOCK_REFERER, groups=MOCK_GROUPS, username=MOCK_USER +) class TestSingleSignOnController(FunctionalTest): def test_sso_enabled(self): cfg.CONF.set_override(group="auth", name="sso", override=True) response = self.app.get(SSO_V1_PATH, expect_errors=False) - self.assertTrue(response.status_code, http_client.OK) + self.assertEqual(response.status_code, http_client.OK) self.assertDictEqual(response.json, {"enabled": True}) def test_sso_disabled(self): cfg.CONF.set_override(group="auth", name="sso", override=False) response = self.app.get(SSO_V1_PATH, expect_errors=False) - self.assertTrue(response.status_code, http_client.OK) + self.assertEqual(response.status_code, http_client.OK) self.assertDictEqual(response.json, {"enabled": False}) @mock.patch.object( @@ -57,76 +104,282 @@ def test_sso_disabled(self): def test_unknown_exception(self): cfg.CONF.set_override(group="auth", name="sso", override=True) response = self.app.get(SSO_V1_PATH, expect_errors=False) - self.assertTrue(response.status_code, http_client.OK) + self.assertEqual(response.status_code, http_client.OK) self.assertDictEqual(response.json, {"enabled": False}) self.assertTrue( sso_api_controller.SingleSignOnController._get_sso_enabled_config.called ) +# Base SSO request test class, to be used by CLI/WEB class TestSingleSignOnRequestController(FunctionalTest): + + # + # Settupers + # + + # Cleanup sso requests + def setUp(self): + for x in SSORequest.get_all(): + SSORequest.delete(x) + + # + # Helpers + # + + def _assert_response(self, response, status_code, expected_body): + self.assertEqual(response.status_code, status_code) + self.assertDictEqual(response.json, expected_body) + + def _assert_sso_requests_len(self, expected): + sso_requests: List[SSORequestDB] = SSORequest.get_all() + self.assertEqual(len(sso_requests), expected) + return sso_requests + + def _assert_sso_request_success(self, sso_request, type): + self.assertEqual(sso_request.type, type) + self.assertLessEqual( + abs( + sso_request.expiry.timestamp() + - date_utils.get_datetime_utc_now().timestamp() + - DEFAULT_SSO_REQUEST_TTL + ), + 2, + ) + sso_api_controller.SSO_BACKEND.get_request_redirect_url.assert_called_with( + sso_request.request_id, MOCK_REFERER + ) + + def _test_cli_request_bad_parameter_helper(self, params, expected_error): + response = self._default_cli_request(params=params, expect_errors=True) + self._assert_response( + response, http_client.BAD_REQUEST, {"faultstring": expected_error} + ) + self._assert_sso_requests_len(0) + + def _default_web_request(self, expect_errors): + return self.app.get( + SSO_REQUEST_WEB_V1_PATH, + headers={"referer": MOCK_REFERER}, + expect_errors=expect_errors, + ) + + def _default_cli_request( + self, + params={"callback_url": MOCK_CALLBACK_URL, "key": MOCK_CLI_REQUEST_KEY_JSON}, + expect_errors=False, + ): + return self.app.post( + SSO_REQUEST_CLI_V1_PATH, + content_type="application/json", + params=json.dumps(params), + expect_errors=expect_errors, + ) + + # + # Tests :) + # + @mock.patch.object( sso_api_controller.SSO_BACKEND, "get_request_redirect_url", mock.MagicMock(side_effect=Exception("fooobar")), ) - def test_default_backend_unknown_exception(self): - expected_error = {"faultstring": "Internal Server Error"} - response = self.app.get(SSO_REQUEST_V1_PATH, expect_errors=True) - self.assertTrue(response.status_code, http_client.INTERNAL_SERVER_ERROR) - self.assertDictEqual(response.json, expected_error) + def test_web_default_backend_unknown_exception(self): + response = self._default_web_request(True) + self._assert_response( + response, + http_client.INTERNAL_SERVER_ERROR, + {"faultstring": "Internal Server Error"}, + ) + self._assert_sso_requests_len(0) - def test_default_backend_not_implemented(self): - expected_error = {"faultstring": noop.NOT_IMPLEMENTED_MESSAGE} - response = self.app.get(SSO_REQUEST_V1_PATH, expect_errors=True) - self.assertTrue(response.status_code, http_client.INTERNAL_SERVER_ERROR) - self.assertDictEqual(response.json, expected_error) + def test_web_default_backend_invalid_key(self): + response = self._default_web_request(True) + self._assert_response( + response, + http_client.INTERNAL_SERVER_ERROR, + {"faultstring": noop.NOT_IMPLEMENTED_MESSAGE}, + ) + self._assert_sso_requests_len(0) + + def test_web_default_backend_not_implemented(self): + response = self._default_web_request(True) + self._assert_response( + response, + http_client.INTERNAL_SERVER_ERROR, + {"faultstring": noop.NOT_IMPLEMENTED_MESSAGE}, + ) + self._assert_sso_requests_len(0) @mock.patch.object( sso_api_controller.SSO_BACKEND, "get_request_redirect_url", mock.MagicMock(return_value="https://127.0.0.1"), ) - def test_idp_redirect(self): - response = self.app.get(SSO_REQUEST_V1_PATH, expect_errors=False) - self.assertTrue(response.status_code, http_client.TEMPORARY_REDIRECT) + def test_web_idp_redirect(self): + response = self._default_web_request(False) + self.assertEqual(response.status_code, http_client.TEMPORARY_REDIRECT) self.assertEqual(response.location, "https://127.0.0.1") + # Make sure we have created a SSO request based on this call :) + sso_requests = self._assert_sso_requests_len(1) + sso_request = sso_requests[0] + self._assert_sso_request_success(sso_request, SSORequestDB.Type.WEB) -class TestIdentityProviderCallbackController(FunctionalTest): @mock.patch.object( sso_api_controller.SSO_BACKEND, - "verify_response", + "get_request_redirect_url", mock.MagicMock(side_effect=Exception("fooobar")), ) - def test_default_backend_unknown_exception(self): - expected_error = {"faultstring": "Internal Server Error"} - response = self.app.post_json( - SSO_CALLBACK_V1_PATH, {"foo": "bar"}, expect_errors=True + def test_cli_default_backend_unknown_exception(self): + response = self._default_cli_request(expect_errors=True) + self._assert_response( + response, + http_client.INTERNAL_SERVER_ERROR, + {"faultstring": "Internal Server Error"}, ) - self.assertTrue(response.status_code, http_client.INTERNAL_SERVER_ERROR) - self.assertDictEqual(response.json, expected_error) + self._assert_sso_requests_len(0) - def test_default_backend_not_implemented(self): - expected_error = {"faultstring": noop.NOT_IMPLEMENTED_MESSAGE} - response = self.app.post_json( - SSO_CALLBACK_V1_PATH, {"foo": "bar"}, expect_errors=True + def test_cli_default_backend_bad_key(self): + self._test_cli_request_bad_parameter_helper( + {"callback_url": MOCK_CALLBACK_URL, "key": "bad-key"}, + "The provided key is invalid! It should be stackstorm-compatible AES key", ) - self.assertTrue(response.status_code, http_client.INTERNAL_SERVER_ERROR) - self.assertDictEqual(response.json, expected_error) + + def test_cli_default_backend_missing_key(self): + self._test_cli_request_bad_parameter_helper( + { + "callback_url": MOCK_CALLBACK_URL, + }, + "'key' is a required property", + ) + + def test_cli_default_backend_missing_callback_url(self): + self._test_cli_request_bad_parameter_helper( + { + "key": MOCK_CLI_REQUEST_KEY_JSON, + }, + "'callback_url' is a required property", + ) + + def test_cli_default_backend_missing_key_and_callback_url(self): + self._test_cli_request_bad_parameter_helper( + {"ops": "ops"}, "'key' is a required property" + ) + + def test_cli_default_backend_not_implemented(self): + response = self._default_cli_request(expect_errors=True) + self._assert_response( + response, + http_client.INTERNAL_SERVER_ERROR, + {"faultstring": noop.NOT_IMPLEMENTED_MESSAGE}, + ) + self._assert_sso_requests_len(0) @mock.patch.object( sso_api_controller.SSO_BACKEND, - "verify_response", - mock.MagicMock(return_value={"referer": MOCK_REFERER, "username": MOCK_USER}), + "get_request_redirect_url", + mock.MagicMock(return_value="https://127.0.0.1"), ) - def test_idp_callback(self): - expected_body = sso_api_controller.CALLBACK_SUCCESS_RESPONSE_BODY % MOCK_REFERER - response = self.app.post_json( - SSO_CALLBACK_V1_PATH, {"foo": "bar"}, expect_errors=False + def test_cli_default_backend(self): + response = self._default_cli_request( + params={"callback_url": MOCK_REFERER, "key": MOCK_CLI_REQUEST_KEY_JSON}, + expect_errors=False, + ) + + # Make sure we have created a SSO request based on this call :) + sso_requests = self._assert_sso_requests_len(1) + sso_request = sso_requests[0] + self._assert_sso_request_success(sso_request, SSORequestDB.Type.CLI) + self._assert_response( + response, + http_client.OK, + {"sso_url": "https://127.0.0.1", "expiry": sso_request.expiry.isoformat()}, ) - self.assertTrue(response.status_code, http_client.OK) - self.assertEqual(expected_body, response.body.decode("utf-8")) + + +class TestIdentityProviderCallbackController(FunctionalTest): + def setUp(self): + for x in SSORequest.get_all(): + SSORequest.delete(x) + + def setUp_for_rbac(self): + # Set up standard roles + for x in Role.get_all(): + Role.delete(x) + + RoleDB(name="system_admin", system=True).save() + RoleDB(name="admin", system=True).save() + RoleDB(name="my-test", system=True).save() + + # Cleanup user assignments + for x in UserRoleAssignment.get_all(): + UserRoleAssignment.delete(x) + + for x in GroupToRoleMapping.get_all(): + GroupToRoleMapping.delete(x) + + # Set up assignment mappings + GroupToRoleMappingDB( + group="test2", roles=["system_admin", "admin"], source="test", enabled=True + ).save() + + GroupToRoleMappingDB( + group="test", roles=["my-test"], source="test", enabled=True + ).save() + + cfg.CONF.set_override(group="rbac", name="enable", override=True) + cfg.CONF.set_override(group="rbac", name="backend", override="default") + + def tearDown_for_rbac(self): + + for x in UserRoleAssignment.get_all(): + UserRoleAssignment.delete(x) + + for x in GroupToRoleMapping.get_all(): + GroupToRoleMapping.delete(x) + + for x in Role.get_all(): + Role.delete(x) + + cfg.CONF.set_override(group="rbac", name="enable", override=False) + cfg.CONF.set_override(group="rbac", name="backend", override="default") + + # Helpers + # + + def _assert_response( + self, response, status_code, expected_body, response_type="json" + ): + self.assertEqual(response.status_code, status_code) + if response_type == "json": + self.assertDictEqual(response.json, expected_body) + else: + self.assertEqual(response.body.decode("utf-8"), expected_body) + + def _assert_sso_requests_len(self, expected): + sso_requests: List[SSORequestDB] = SSORequest.get_all() + self.assertEqual(len(sso_requests), expected) + return sso_requests + + def _assert_role_assignment_len(self, expected): + role_assignments: List[UserRoleAssignment] = UserRoleAssignment.get_all() + self.assertEqual(len(role_assignments), expected) + return role_assignments + + def _assert_token_data_is_valid(self, token_data): + self.assertEqual(token_data["user"], MOCK_USER) + self.assertIsNotNone(token_data["expiry"]) + self.assertIsNotNone(token_data["token"]) + + # Validate actual token :) + token = Token.get(token_data["token"]) + self.assertIsNotNone(token) + self.assertEqual(token.user, MOCK_USER) + self.assertEqual(token.expiry.isoformat()[0:19], token_data["expiry"][0:19]) + + def _assert_response_has_token_cookie_only(self, response): set_cookies_list = [h for h in response.headerlist if h[0] == "Set-Cookie"] self.assertEqual(len(set_cookies_list), 1) @@ -135,19 +388,269 @@ def test_idp_callback(self): cookie = urllib.parse.unquote(set_cookies_list[0][1]).split("=") st2_auth_token = json.loads(cookie[1].split(";")[0]) self.assertIn("token", st2_auth_token) - self.assertEqual(st2_auth_token["user"], MOCK_USER) + + return st2_auth_token + + def _default_callback_request(self, params={}, expect_errors=False): + return self.app.post_json( + SSO_CALLBACK_V1_PATH, params, expect_errors=expect_errors + ) + + # + # Tests + # + + @mock.patch.object( + sso_api_controller.SSO_BACKEND, + "get_request_id_from_response", + mock.MagicMock(side_effect=Exception("fooobar")), + ) + def test_default_backend_unknown_exception(self): + response = self._default_callback_request({"foo": "bar"}, expect_errors=True) + self._assert_response( + response, + http_client.INTERNAL_SERVER_ERROR, + {"faultstring": "Internal Server Error"}, + ) + + def test_default_backend_not_implemented(self): + response = self._default_callback_request({"foo": "bar"}, expect_errors=True) + self._assert_response( + response, + http_client.INTERNAL_SERVER_ERROR, + {"faultstring": noop.NOT_IMPLEMENTED_MESSAGE}, + ) + + @mock.patch.object( + sso_api_controller.SSO_BACKEND, + "get_request_id_from_response", + mock.MagicMock(return_value=None), + ) + def test_default_backend_invalid_request_id(self): + response = self._default_callback_request({"foo": "bar"}, expect_errors=True) + self._assert_response( + response, + http_client.BAD_REQUEST, + {"faultstring": "Invalid request id coming from SAML response"}, + ) + + @mock.patch.object( + sso_api_controller.SSO_BACKEND, + "get_request_id_from_response", + mock.MagicMock(return_value=MOCK_REQUEST_ID), + ) + @mock.patch.object( + sso_api_controller.SSO_BACKEND, + "verify_response", + mock.MagicMock(return_value={"test": "user"}), + ) + def test_default_backend_invalid_backend_response(self): + create_web_sso_request(MOCK_REQUEST_ID) + response = self._default_callback_request({"foo": "bar"}, expect_errors=True) + self._assert_response( + response, + http_client.INTERNAL_SERVER_ERROR, + { + "faultstring": ( + "Unexpected SSO backend response type." + " Expected BaseSingleSignOnBackendResponse instance!" + ) + }, + ) + + @mock.patch.object( + sso_api_controller.SSO_BACKEND, + "get_request_id_from_response", + mock.MagicMock(return_value=MOCK_REQUEST_ID), + ) + def test_idp_callback_missing_sso_request(self): + self._assert_sso_requests_len(0) + response = self._default_callback_request({"foo": "bar"}, expect_errors=True) + + self._assert_response( + response, + http_client.BAD_REQUEST, + { + "faultstring": "This SSO request is invalid (it may have already been used)" + }, + ) + @mock.patch.object( + sso_api_controller.SSO_BACKEND, + "get_request_id_from_response", + mock.MagicMock(return_value=MOCK_REQUEST_ID), + ) + def test_idp_callback_sso_request_expired(self): + # given + # Create fake expired request + create_web_sso_request(MOCK_REQUEST_ID, -20) + self._assert_sso_requests_len(1) + response = self._default_callback_request({"foo": "bar"}, expect_errors=True) + + self._assert_response( + response, + http_client.BAD_REQUEST, + { + "faultstring": "The SSO request associated with this response has already expired!" + }, + ) + + @mock.patch.object( + sso_api_controller.SSO_BACKEND, + "get_request_id_from_response", + mock.MagicMock(return_value=MOCK_REQUEST_ID), + ) @mock.patch.object( sso_api_controller.SSO_BACKEND, "verify_response", - mock.MagicMock(return_value={"referer": MOCK_REFERER, "username": MOCK_USER}), + mock.MagicMock(return_value=MOCK_VERIFIED_USER_OBJECT), + ) + def _test_idp_callback_web(self): + # given + # Create fake request + create_web_sso_request(MOCK_REQUEST_ID) + self._assert_sso_requests_len(1) + + # when + # Callback based onthe fake request :) -- as mocked above + response = self._default_callback_request({"foo": "bar"}, expect_errors=False) + + # then + # Validate request has been processed and response is as expected + self._assert_sso_requests_len(0) + self._assert_response( + response, + http_client.OK, + sso_api_controller.CALLBACK_SUCCESS_RESPONSE_BODY % MOCK_REFERER, + "str", + ) + + # Validate token is valid + token_data = self._assert_response_has_token_cookie_only(response) + self._assert_token_data_is_valid(token_data) + + def test_idp_callback_web_without_rbac(self): + self._assert_role_assignment_len(0) + self._test_idp_callback_web() + self._assert_role_assignment_len(0) + + def test_idp_callback_web_with_rbac(self): + self.setUp_for_rbac() + self._assert_role_assignment_len(0) + + self._test_idp_callback_web() + + self._assert_role_assignment_len(3) + self.tearDown_for_rbac() + + @mock.patch.object( + sso_api_controller.SSO_BACKEND, + "get_request_id_from_response", + mock.MagicMock(return_value=MOCK_REQUEST_ID), + ) + @mock.patch.object( + sso_api_controller.SSO_BACKEND, + "verify_response", + mock.MagicMock(return_value=MOCK_VERIFIED_USER_OBJECT), + ) + def _test_idp_callback_cli(self): + # given + # Create fake request + create_cli_sso_request(MOCK_REQUEST_ID, MOCK_CLI_REQUEST_KEY_JSON) + self._assert_sso_requests_len(1) + + # when + # Callback based onthe fake request :) -- as mocked above + response = self._default_callback_request({"foo": "bar"}, expect_errors=False) + + # then + # Validate request has been processed and response is as expected + self._assert_sso_requests_len(0) + self.assertEqual(response.status_code, http_client.FOUND) + self.assertRegex( + response.location, "^" + MOCK_REFERER + r"\?response=[A-Z0-9]+$" + ) + + # decrypt token + encrypted_response = response.location.split("response=")[1] + token_data_json = symmetric_decrypt(MOCK_CLI_REQUEST_KEY, encrypted_response) + self.assertIsNotNone(token_data_json) + + # Validate token is valid + token_data = json.loads(token_data_json) + self._assert_token_data_is_valid(token_data) + + def test_idp_callback_cli_without_rbac(self): + self._assert_role_assignment_len(0) + self._test_idp_callback_cli() + self._assert_role_assignment_len(0) + + def test_idp_callback_cli_with_rbac(self): + self.setUp_for_rbac() + self._assert_role_assignment_len(0) + + self._test_idp_callback_cli() + + self._assert_role_assignment_len(3) + self.tearDown_for_rbac() + + @mock.patch.object( + sso_api_controller.SSO_BACKEND, + "get_request_id_from_response", + mock.MagicMock(return_value=MOCK_REQUEST_ID), + ) + @mock.patch.object( + sso_api_controller.SSO_BACKEND, + "verify_response", + mock.MagicMock(return_value=MOCK_VERIFIED_USER_OBJECT), + ) + def test_idp_callback_cli_invalid_decryption_key(self): + # given + # Create fake request + create_cli_sso_request(MOCK_REQUEST_ID, MOCK_CLI_REQUEST_KEY_JSON) + self._assert_sso_requests_len(1) + self._assert_role_assignment_len(0) + + # when + # Callback based onthe fake request :) -- as mocked above + response = self._default_callback_request({"foo": "bar"}, expect_errors=False) + + # then + # Validate request has been processed and response is as expected + self._assert_sso_requests_len(0) + self._assert_role_assignment_len(0) + self.assertEqual(response.status_code, http_client.FOUND) + self.assertRegex( + response.location, "^" + MOCK_REFERER + r"\?response=[A-Z0-9]+$" + ) + + # decrypt token + encrypted_response = response.location.split("response=")[1] + with self.assertRaises(Exception): + symmetric_decrypt(MOCK_CLI_REQUEST_KEY_ALTERNATIVE, encrypted_response) + + @mock.patch.object( + sso_api_controller.SSO_BACKEND, + "get_request_id_from_response", + mock.MagicMock(return_value=MOCK_REQUEST_ID), + ) + @mock.patch.object( + sso_api_controller.SSO_BACKEND, + "verify_response", + mock.MagicMock(return_value=MOCK_VERIFIED_USER_OBJECT), ) def test_callback_url_encoded_payload(self): + create_web_sso_request(MOCK_REQUEST_ID) data = {"foo": ["bar"]} headers = {"Content-Type": "application/x-www-form-urlencoded"} response = self.app.post(SSO_CALLBACK_V1_PATH, data, headers=headers) - self.assertTrue(response.status_code, http_client.OK) + self.assertEqual(response.status_code, http_client.OK) + @mock.patch.object( + sso_api_controller.SSO_BACKEND, + "get_request_id_from_response", + mock.MagicMock(return_value=MOCK_REQUEST_ID), + ) @mock.patch.object( sso_api_controller.SSO_BACKEND, "verify_response", @@ -156,9 +659,10 @@ def test_callback_url_encoded_payload(self): ), ) def test_idp_callback_verification_failed(self): + create_web_sso_request(MOCK_REQUEST_ID) expected_error = {"faultstring": "Verification Failed"} response = self.app.post_json( SSO_CALLBACK_V1_PATH, {"foo": "bar"}, expect_errors=True ) - self.assertTrue(response.status_code, http_client.UNAUTHORIZED) + self.assertEqual(response.status_code, http_client.UNAUTHORIZED) self.assertDictEqual(response.json, expected_error) diff --git a/st2auth/tests/unit/controllers/v1/test_token.py b/st2auth/tests/unit/controllers/v1/test_token.py index e56c7e9acb..445a9ad998 100644 --- a/st2auth/tests/unit/controllers/v1/test_token.py +++ b/st2auth/tests/unit/controllers/v1/test_token.py @@ -26,6 +26,12 @@ import mock from oslo_config import cfg +# NOTE: We need to perform monkeypatch before importing ssl module otherwise tests will fail. +# See https://github.com/StackStorm/st2/pull/4834 for details +from st2common.util.monkey_patch import monkey_patch + +monkey_patch() + from tests.base import FunctionalTest from st2common.util import isotime from st2common.util import date as date_utils diff --git a/st2auth/tests/unit/test_handlers.py b/st2auth/tests/unit/test_handlers.py index cf00e642a6..24e0012dc1 100644 --- a/st2auth/tests/unit/test_handlers.py +++ b/st2auth/tests/unit/test_handlers.py @@ -30,23 +30,125 @@ from st2tests.mocks.auth import MockRequest from st2tests.mocks.auth import get_mock_backend +from st2common.persistence.rbac import UserRoleAssignment +from st2common.models.db.rbac import GroupToRoleMappingDB, RoleDB + + __all__ = ["AuthHandlerTestCase"] +from st2common.router import GenericRequestParam + +MOCK_USER = "test_proxy_handler" + @mock.patch("st2auth.handlers.get_auth_backend_instance", get_mock_backend) -class AuthHandlerTestCase(CleanDbTestCase): +class ProxyHandlerRBACAndGroupsTestCase(CleanDbTestCase): + def _assert_roles_len(self, user, total): + user_roles = UserRoleAssignment.get_all(user=user) + self.assertEqual(len(user_roles), total) + return user_roles + def setUp(self): - super(AuthHandlerTestCase, self).setUp() + super(ProxyHandlerRBACAndGroupsTestCase, self).setUp() cfg.CONF.auth.backend = "mock" - def test_proxy_handler(self): + # Create test roles + RoleDB(name="role-1").save() + RoleDB(name="role-2").save() + + # Create tsts mappings + GroupToRoleMappingDB( + group="group-1", roles=["role-1"], source="test", enabled=True + ).save() + + GroupToRoleMappingDB( + group="group-2", roles=["role-2"], source="test", enabled=True + ).save() + + cfg.CONF.set_override(name="enable", group="rbac", override=False) + cfg.CONF.set_override(name="backend", group="rbac", override="noop") + + def test_proxy_handler_no_groups_no_rbac(self): h = handlers.ProxyAuthHandler() request = {} token = h.handle_auth( - request, headers={}, remote_addr=None, remote_user="test_proxy_handler" + request, headers={}, remote_addr=None, remote_user=MOCK_USER + ) + self._assert_roles_len(token.user, 0) + self.assertEqual(token.user, MOCK_USER) + + def test_proxy_handler_with_groups_and_rbac_disabled(self): + + h = handlers.ProxyAuthHandler() + + request = GenericRequestParam(groups=["group-1", "group-2"]) + token = h.handle_auth( + request, headers={}, remote_addr=None, remote_user=MOCK_USER ) - self.assertEqual(token.user, "test_proxy_handler") + self._assert_roles_len(token.user, 0) + + self.assertEqual(token.user, MOCK_USER) + + def test_proxy_handler_with_groups_and_rbac_enabled(self): + + cfg.CONF.set_override(name="enable", group="rbac", override=True) + cfg.CONF.set_override(name="backend", group="rbac", override="default") + + h = handlers.ProxyAuthHandler() + + request = GenericRequestParam(groups=["group-1", "group-2"]) + token = h.handle_auth( + request, headers={}, remote_addr=None, remote_user=MOCK_USER + ) + + self.assertEqual(token.user, MOCK_USER) + user_roles = self._assert_roles_len(token.user, 2) + self.assertEqual(user_roles[0].role, "role-1") + self.assertEqual(user_roles[1].role, "role-2") + + def test_proxy_handler_no_groups_and_rbac_enabled_with_no_prior_roles(self): + + cfg.CONF.set_override(name="enable", group="rbac", override=True) + cfg.CONF.set_override(name="backend", group="rbac", override="default") + + h = handlers.ProxyAuthHandler() + + request = GenericRequestParam(groups=[]) + token = h.handle_auth( + request, headers={}, remote_addr=None, remote_user=MOCK_USER + ) + user_roles = self._assert_roles_len(token.user, 0) + + self.assertEqual(token.user, MOCK_USER) + self.assertEqual(len(user_roles), 0) + + def test_proxy_handler_no_groups_and_rbac_enabled_with_prior_roles(self): + + self.test_proxy_handler_with_groups_and_rbac_enabled() + self._assert_roles_len(MOCK_USER, 2) + + cfg.CONF.set_override(name="enable", group="rbac", override=True) + cfg.CONF.set_override(name="backend", group="rbac", override="default") + + h = handlers.ProxyAuthHandler() + + request = GenericRequestParam(groups=[]) + token = h.handle_auth( + request, headers={}, remote_addr=None, remote_user=MOCK_USER + ) + user_roles = UserRoleAssignment.get_all(user=token.user) + + self.assertEqual(token.user, MOCK_USER) + self.assertEqual(len(user_roles), 0) + + +@mock.patch("st2auth.handlers.get_auth_backend_instance", get_mock_backend) +class AuthHandlerTestCase(CleanDbTestCase): + def setUp(self): + super(AuthHandlerTestCase, self).setUp() + + cfg.CONF.auth.backend = "mock" def test_standalone_bad_auth_type(self): h = handlers.StandaloneAuthHandler() diff --git a/st2client/st2client/client.py b/st2client/st2client/client.py index 4b3e74e849..318455f375 100644 --- a/st2client/st2client/client.py +++ b/st2client/st2client/client.py @@ -23,7 +23,7 @@ from st2client import models from st2client.utils import httpclient -from st2client.models.core import ResourceManager +from st2client.models.core import ResourceManager, TokenResourceManager from st2client.models.core import ActionAliasResourceManager from st2client.models.core import ActionAliasExecutionManager from st2client.models.core import ActionResourceManager @@ -145,7 +145,7 @@ def __init__( # Instantiate resource managers and assign appropriate API endpoint. self.managers = dict() - self.managers["Token"] = ResourceManager( + self.managers["Token"] = TokenResourceManager( models.Token, self.endpoints["auth"], cacert=self.cacert, diff --git a/st2client/st2client/commands/auth.py b/st2client/st2client/commands/auth.py index 71437a1751..e8abfdd99a 100644 --- a/st2client/st2client/commands/auth.py +++ b/st2client/st2client/commands/auth.py @@ -19,7 +19,7 @@ import json import logging import os - +import webbrowser import requests import six from six.moves.configparser import ConfigParser @@ -32,11 +32,16 @@ from st2client.commands.noop import NoopCommand from st2client.exceptions.operations import OperationFailureException from st2client.formatters import table +from st2client.utils.date import format_isodate_for_user_timezone LOG = logging.getLogger(__name__) +class MissingUserNameException(Exception): + pass + + class TokenCreateCommand(resource.ResourceCommand): display_attributes = ["user", "token", "expiry"] @@ -118,8 +123,36 @@ def __init__(self, resource, *args, **kwargs): **kwargs, ) - self.parser.add_argument("username", help="Name of the user to authenticate.") + self.parser.add_argument( + "username", + nargs="?", + default=None, + help="Name of the user to authenticate (not needed if --sso is used).", + ) + self.parser.add_argument( + "-s", + "--sso", + dest="sso", + action="store_true", + help="Whether to use SSO authentication or not. " + "If chosen, bypasses username/password.", + ) + self.parser.add_argument( + "-P", + "--sso-port", + dest="sso_port", + type=int, + default=0, + help="Fixed SSO port to use for local callback server. Default is 0, which is random", + ) + self.parser.add_argument( + "--no-sso-browser", + dest="no_sso_browser", + action="store_true", + default=False, + help="Prevents from automatically launching the browser for SSO", + ) self.parser.add_argument( "-p", "--password", @@ -143,13 +176,11 @@ def __init__(self, resource, *args, **kwargs): default=False, dest="write_password", help="Write the password in plain text to the config file " - "(default is to omit it)", + "(only applicable to username/password login, and default is to omit it)", ) def run(self, args, **kwargs): - if not args.password: - args.password = getpass.getpass() instance = self.resource(ttl=args.ttl) if args.ttl else self.resource() cli = BaseCLIApp() @@ -161,11 +192,47 @@ def run(self, args, **kwargs): # config file not found in args or in env, defaulting config_file = config_parser.ST2_CONFIG_PATH - # Retrieve token - manager = self.manager.create( - instance, auth=(args.username, args.password), **kwargs - ) - cli._cache_auth_token(token_obj=manager) + # Retrieve token based on whether we're using SSO or username/password login :) + if args.sso: + LOG.debug("Logging in with SSO with fixed port [%d]", args.sso_port) + # Retrieve token from SSO backend + sso_proxy = self.manager.create_sso_request(args.sso_port, **kwargs) + + if args.no_sso_browser: + print( + "Please finish your SSO login by visiting: %s" + % (sso_proxy.get_proxy_url()) + ) + else: + print( + "Please finish the SSO login on your browser.\n" + "If the browser hasn't opened automatically, please visit: %s" + % (sso_proxy.get_proxy_url()) + ) + webbrowser.open(sso_proxy.get_proxy_url()) + + try: + token = self.manager.wait_for_sso_token(sso_proxy) + except KeyboardInterrupt: + raise Exception("SSO Login aborted by user") + + # Defaults to username/password if not SSO + else: + LOG.debug("Logging in with username/password") + if not args.username: + raise MissingUserNameException( + "Username expected when not using SSO login" + ) + + if not args.password: + args.password = getpass.getpass() + + # Retrieve token from username/password auth api + token = self.manager.create( + instance, auth=(args.username, args.password), **kwargs + ) + + cli._cache_auth_token(token_obj=token) # Update existing configuration with new credentials config = ConfigParser() @@ -175,8 +242,9 @@ def run(self, args, **kwargs): if not config.has_section("credentials"): config.add_section("credentials") - config.set("credentials", "username", args.username) - if args.write_password: + config.set("credentials", "username", token.user) + + if args.write_password and not args.sso: config.set("credentials", "password", args.password) else: # Remove any existing password from config @@ -189,37 +257,40 @@ def run(self, args, **kwargs): if not config_existed: os.chmod(config_file, 0o660) - return manager + return token def run_and_print(self, args, **kwargs): + try: - self.run(args, **kwargs) + token = self.run(args, **kwargs) + formatted_expiry = format_isodate_for_user_timezone(token.expiry) + print("Logged in as %s until %s" % (token.user, formatted_expiry)) + + if not args.write_password and not args.sso: + print("") + print( + "Note: You didn't use --write-password option so the password hasn't been " + "stored in the client config and you will need to login again after %s hours when " + "the auth token expires." % (formatted_expiry) + ) + print( + 'As an alternative, you can run st2 login command with the "--write-password" ' + "flag, but keep it mind this will cause it to store the password in plain-text " + "in the client config file (~/.st2/config)." + ) + except MissingUserNameException as e: + raise e except Exception as e: if self.app.client.debug: raise - raise Exception( - "Failed to log in as %s: %s" % (args.username, six.text_type(e)) - ) - - print("Logged in as %s" % (args.username)) + if args.sso: + raise Exception("Could not perform SSO login: %s" % (six.text_type(e))) - if not args.write_password: - # Note: Client can't depend and import from common so we need to hard-code this - # default value - token_expire_hours = 24 - - print("") - print( - "Note: You didn't use --write-password option so the password hasn't been " - "stored in the client config and you will need to login again in %s hours when " - "the auth token expires." % (token_expire_hours) - ) - print( - 'As an alternative, you can run st2 login command with the "--write-password" ' - "flag, but keep it mind this will cause it to store the password in plain-text " - "in the client config file (~/.st2/config)." - ) + else: + raise Exception( + "Failed to log in as %s: %s" % (args.username, six.text_type(e)) + ) class WhoamiCommand(resource.ResourceCommand): diff --git a/st2client/st2client/models/core.py b/st2client/st2client/models/core.py index e66e0e7800..b3be1764db 100644 --- a/st2client/st2client/models/core.py +++ b/st2client/st2client/models/core.py @@ -28,8 +28,11 @@ from six.moves import http_client import requests -from st2client.utils import httpclient +from st2client.utils.crypto import AESKey + +from st2client.utils import httpclient +from st2client.utils import sso_interceptor LOG = logging.getLogger(__name__) @@ -881,6 +884,53 @@ def list(self, group_id, **kwargs): return result +class TokenResourceManager(ResourceManager): + + # This will spin up a local web server to mediate the requests from/to the sso + # endpoint, so that we can intercept the callback and token :) + # + # This function will not retrieve the token directly because we still need + # to print out some interaction with the user and that's best done elsewhere, so + # we'll just provide back the "interceptor" object, that is able to provide + # the URL and wait for the token to be ready :) + def create_sso_request( + self, sso_port=0, **kwargs + ) -> sso_interceptor.SSOInterceptorProxy: + url = "/sso/request/cli" + + key = AESKey.generate() + sso_proxy = sso_interceptor.SSOInterceptorProxy(key, sso_port) + + response = self.client.post( + url, + {"key": key.to_json(), "callback_url": sso_proxy.get_callback_url()}, + **kwargs, + ) + + if response.status_code != http_client.OK: + self.handle_error(response) + + json_response = response.json() + if not type(json_response) is dict: + raise ValueError( + "Expected response body from SSO CLI request, but couldn't find one :( " + ) + + sso_url = response.json().get("sso_url", None) + if sso_url is None: + raise ValueError( + "Expected SSO URL to be present in SSO login request response!" + ) + + sso_proxy.set_sso_url(sso_url) + + LOG.debug("Received SSO URL with lenght %d", len(sso_url)) + return sso_proxy + + def wait_for_sso_token(self, sso_proxy): + return self.resource.deserialize(sso_proxy.get_token()) + + class KeyValuePairResourceManager(ResourceManager): @add_auth_token_to_kwargs_from_env def get_by_name(self, name, **kwargs): diff --git a/st2client/st2client/utils/crypto.py b/st2client/st2client/utils/crypto.py new file mode 100644 index 0000000000..138e6b165f --- /dev/null +++ b/st2client/st2client/utils/crypto.py @@ -0,0 +1,503 @@ +# Copyright 2020 The StackStorm Authors. +# Copyright 2019 Extreme Networks, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Module for handling symmetric encryption and decryption of short text values (mostly used for +encrypted datastore values aka secrets). + +NOTE: In the past, this module used and relied on keyczar, but since keyczar doesn't support +Python 3, we moved to cryptography library. + +symmetric_encrypt and symmetric_decrypt functions except values as returned by the AESKey.Encrypt() +and AESKey.Decrypt() methods in keyczar. Those functions follow the same approach (AES in CBC mode +with SHA1 HMAC signature) as keyczar methods, but they use and rely on primitives and methods from +the cryptography library. + +This was done to make the keyczar -> cryptography migration fully backward compatible. + +Eventually, we should move to Fernet (https://cryptography.io/en/latest/fernet/) recipe for +symmetric encryption / decryption, because it offers more robustness and safer defaults (SHA256 +instead of SHA1, etc.). +""" + +from __future__ import absolute_import + +import os +import binascii +import base64 + +from hashlib import sha1 + +import six + +from cryptography.hazmat.primitives.ciphers import Cipher +from cryptography.hazmat.primitives.ciphers import algorithms +from cryptography.hazmat.primitives.ciphers import modes +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives import hmac +from cryptography.hazmat.backends import default_backend + +from st2client.utils.jsonify import json_encode +from st2client.utils.jsonify import json_decode + +__all__ = [ + "KEYCZAR_HEADER_SIZE", + "KEYCZAR_AES_BLOCK_SIZE", + "KEYCZAR_HLEN", + "read_crypto_key", + "symmetric_encrypt", + "symmetric_decrypt", + "cryptography_symmetric_encrypt", + "cryptography_symmetric_decrypt", + # NOTE: Keyczar functions are here for testing reasons - they are only used by tests + "keyczar_symmetric_encrypt", + "keyczar_symmetric_decrypt", + "AESKey", +] + +# Keyczar related constants +KEYCZAR_HEADER_SIZE = 5 +KEYCZAR_AES_BLOCK_SIZE = 16 +KEYCZAR_HLEN = sha1().digest_size + +# Minimum key size which can be used for symmetric crypto +MINIMUM_AES_KEY_SIZE = 128 + +DEFAULT_AES_KEY_SIZE = 256 + +if DEFAULT_AES_KEY_SIZE < MINIMUM_AES_KEY_SIZE: + raise ValueError( + 'AES key size "%s" is smaller than minimun key size "%s".' + % (DEFAULT_AES_KEY_SIZE, MINIMUM_AES_KEY_SIZE) + ) + + +class AESKey(object): + """ + Class representing AES key object. + """ + + aes_key_string = None + hmac_key_string = None + hmac_key_size = None + mode = None + size = None + + def __init__( + self, + aes_key_string, + hmac_key_string, + hmac_key_size, + mode="CBC", + size=DEFAULT_AES_KEY_SIZE, + ): + if mode not in ["CBC"]: + raise ValueError("Unsupported mode: %s" % (mode)) + + if size < MINIMUM_AES_KEY_SIZE: + raise ValueError("Unsafe key size: %s" % (size)) + + self.aes_key_string = aes_key_string + self.hmac_key_string = hmac_key_string + self.hmac_key_size = int(hmac_key_size) + self.mode = mode.upper() + self.size = int(size) + + # We also store bytes version of the key since bytes are needed by encrypt and decrypt + # methods + self.hmac_key_bytes = Base64WSDecode(self.hmac_key_string) + self.aes_key_bytes = Base64WSDecode(self.aes_key_string) + + @classmethod + def generate(self, key_size=DEFAULT_AES_KEY_SIZE): + """ + Generate a new AES key with the corresponding HMAC key. + + :rtype: :class:`AESKey` + """ + if key_size < MINIMUM_AES_KEY_SIZE: + raise ValueError("Unsafe key size: %s" % (key_size)) + + aes_key_bytes = os.urandom(int(key_size / 8)) + aes_key_string = Base64WSEncode(aes_key_bytes) + + hmac_key_bytes = os.urandom(int(key_size / 8)) + hmac_key_string = Base64WSEncode(hmac_key_bytes) + + return AESKey( + aes_key_string=aes_key_string, + hmac_key_string=hmac_key_string, + hmac_key_size=key_size, + mode="CBC", + size=key_size, + ) + + def to_json(self): + """ + Return JSON representation of this key which is fully compatible with keyczar JSON key + file format. + + :rtype: ``str`` + """ + data = { + "hmacKey": { + "hmacKeyString": self.hmac_key_string, + "size": self.hmac_key_size, + }, + "aesKeyString": self.aes_key_string, + "mode": self.mode.upper(), + "size": int(self.size), + } + return json_encode(data) + + def __repr__(self): + return "" % ( + self.hmac_key_size, + self.mode, + self.size, + ) + + +def read_crypto_key(key_path): + """ + Read crypto key from keyczar JSON key file format and return parsed AESKey object. + + :param key_path: Absolute path to file containing crypto key in Keyczar JSON format. + :type key_path: ``str`` + + :rtype: :class:`AESKey` + """ + with open(key_path, "r") as fp: + content = fp.read() + + content = json_decode(content) + + try: + return read_crypto_key_from_dict(content) + except KeyError as e: + msg = 'Invalid or malformed key file "%s": %s' % (key_path, six.text_type(e)) + raise KeyError(msg) + + +def read_crypto_key_from_dict(key_dict): + """ + Read crypto key from provided Keyczar JSON-format dict and return parsed AESKey object. + + :param key_dict: A dictionary with a key in Keyczar format (same keys as the JSON). + :type key_dict: ``dict`` + + :rtype: :class:`AESKey` + """ + + try: + aes_key = AESKey( + aes_key_string=key_dict["aesKeyString"], + hmac_key_string=key_dict["hmacKey"]["hmacKeyString"], + hmac_key_size=key_dict["hmacKey"]["size"], + mode=key_dict["mode"].upper(), + size=key_dict["size"], + ) + except KeyError as e: + msg = "Invalid or malformed AES key dictionary: %s" % (six.text_type(e)) + raise KeyError(msg) + + return aes_key + + +def symmetric_encrypt(encrypt_key, plaintext): + return cryptography_symmetric_encrypt(encrypt_key=encrypt_key, plaintext=plaintext) + + +def symmetric_decrypt(decrypt_key, ciphertext): + return cryptography_symmetric_decrypt( + decrypt_key=decrypt_key, ciphertext=ciphertext + ) + + +def cryptography_symmetric_encrypt(encrypt_key, plaintext): + """ + Encrypt the provided plaintext using AES encryption. + + NOTE 1: This function return a string which is fully compatible with Keyczar.Encrypt() method. + + NOTE 2: This function is loosely based on keyczar AESKey.Encrypt() (Apache 2.0 license). + + The final encrypted string value consists of: + + [message bytes][HMAC signature bytes for the message] where message consists of + [keyczar header plaintext][IV bytes][ciphertext bytes] + + NOTE: Header itself is unused, but it's added so the format is compatible with keyczar format. + + """ + if not isinstance(encrypt_key, AESKey): + raise TypeError( + "Encrypted key needs to be an AESkey class instance" + f" (was {type(encrypt_key)})." + ) + if not isinstance(plaintext, (six.text_type, six.string_types, six.binary_type)): + raise TypeError( + "Plaintext needs to either be a string/unicode or bytes" + f" (was {type(plaintext)})." + ) + + aes_key_bytes = encrypt_key.aes_key_bytes + hmac_key_bytes = encrypt_key.hmac_key_bytes + + if not isinstance(aes_key_bytes, six.binary_type): + raise TypeError(f"AESKey is not bytes (it is {type(aes_key_bytes)}).") + if not isinstance(hmac_key_bytes, six.binary_type): + raise TypeError(f"HMACKey is not bytes (it is {type(hmac_key_bytes)}).") + + if isinstance(plaintext, (six.text_type, six.string_types)): + # Convert data to bytes + data = plaintext.encode("utf-8") + else: + data = plaintext + + # Pad data + data = pkcs5_pad(data) + + # Generate IV + iv_bytes = os.urandom(KEYCZAR_AES_BLOCK_SIZE) + + backend = default_backend() + cipher = Cipher(algorithms.AES(aes_key_bytes), modes.CBC(iv_bytes), backend=backend) + encryptor = cipher.encryptor() + + # NOTE: We don't care about actual Keyczar header value, we only care about the length (5 + # bytes) so we simply add 5 0's + header_bytes = b"00000" + + ciphertext_bytes = encryptor.update(data) + encryptor.finalize() + msg_bytes = header_bytes + iv_bytes + ciphertext_bytes + + # Generate HMAC signature for the message (header + IV + ciphertext) + h = hmac.HMAC(hmac_key_bytes, hashes.SHA1(), backend=backend) + h.update(msg_bytes) + sig_bytes = h.finalize() + + result = msg_bytes + sig_bytes + + # Convert resulting byte string to hex notation ASCII string + result = binascii.hexlify(result).upper() + + return result + + +def cryptography_symmetric_decrypt(decrypt_key, ciphertext): + """ + Decrypt the provided ciphertext which has been encrypted using symmetric_encrypt() method (it + assumes input is in hex notation as returned by binascii.hexlify). + + NOTE 1: This function assumes ciphertext has been encrypted using symmetric AES crypto from + keyczar library. Underneath it uses crypto primitives from cryptography library which is Python + 3 compatible. + + NOTE 2: This function is loosely based on keyczar AESKey.Decrypt() (Apache 2.0 license). + """ + if not isinstance(decrypt_key, AESKey): + raise TypeError( + "Decrypted key needs to be an AESKey class instance" + f" (was {type(decrypt_key)})." + ) + if not isinstance(ciphertext, (six.text_type, six.string_types, six.binary_type)): + raise TypeError( + "Ciphertext needs to either be a string/unicode or bytes" + f" (was {type(ciphertext)})." + ) + aes_key_bytes = decrypt_key.aes_key_bytes + hmac_key_bytes = decrypt_key.hmac_key_bytes + + if not isinstance(aes_key_bytes, six.binary_type): + raise TypeError(f"AESKey is not bytes (it is {type(aes_key_bytes)}).") + if not isinstance(hmac_key_bytes, six.binary_type): + raise TypeError(f"HMACKey is not bytes (it is {type(hmac_key_bytes)}).") + + # Convert from hex notation ASCII string to bytes + ciphertext = binascii.unhexlify(ciphertext) + + data_bytes = ciphertext[KEYCZAR_HEADER_SIZE:] # remove header + + # Verify ciphertext contains IV + HMAC signature + if len(data_bytes) < (KEYCZAR_AES_BLOCK_SIZE + KEYCZAR_HLEN): + raise ValueError("Invalid or malformed ciphertext (too short)") + + iv_bytes = data_bytes[:KEYCZAR_AES_BLOCK_SIZE] # first block is IV + ciphertext_bytes = data_bytes[ + KEYCZAR_AES_BLOCK_SIZE:-KEYCZAR_HLEN + ] # strip IV and signature + signature_bytes = data_bytes[-KEYCZAR_HLEN:] # last 20 bytes are signature + + # Verify HMAC signature + backend = default_backend() + h = hmac.HMAC(hmac_key_bytes, hashes.SHA1(), backend=backend) + h.update(ciphertext[:-KEYCZAR_HLEN]) + h.verify(signature_bytes) + + # Decrypt ciphertext + cipher = Cipher(algorithms.AES(aes_key_bytes), modes.CBC(iv_bytes), backend=backend) + + decryptor = cipher.decryptor() + decrypted = decryptor.update(ciphertext_bytes) + decryptor.finalize() + + # Unpad + decrypted = pkcs5_unpad(decrypted) + return decrypted + + +### +# NOTE: Those methods below are deprecated and only used for testing purposes +## + + +def keyczar_symmetric_encrypt(encrypt_key, plaintext): + """ + Encrypt the given message using the encrypt_key. Returns a UTF-8 str + ready to be stored in database. Note that we convert the hex notation + to a ASCII notation to produce a UTF-8 friendly string. + + Also, this method will not return the same output on multiple invocations + of same method. The reason is that the Encrypt method uses a different + 'Initialization Vector' per run and the IV is part of the output. + + :param encrypt_key: Symmetric AES key to use for encryption. + :type encrypt_key: :class:`AESKey` + + :param plaintext: Plaintext / message to be encrypted. + :type plaintext: ``str`` + + :rtype: ``str`` + """ + from keyczar.keys import AesKey as KeyczarAesKey # pylint: disable=import-error + from keyczar.keys import HmacKey as KeyczarHmacKey # pylint: disable=import-error + from keyczar.keyinfo import GetMode # pylint: disable=import-error + + encrypt_key = KeyczarAesKey( + encrypt_key.aes_key_string, + KeyczarHmacKey(encrypt_key.hmac_key_string, encrypt_key.hmac_key_size), + encrypt_key.size, + GetMode(encrypt_key.mode), + ) + + return binascii.hexlify(encrypt_key.Encrypt(plaintext)).upper() + + +def keyczar_symmetric_decrypt(decrypt_key, ciphertext): + """ + Decrypt the given crypto text into plain text. Returns the original + string input. Note that we first convert the string to hex notation + and then decrypt. This is reverse of the encrypt operation. + + :param decrypt_key: Symmetric AES key to use for decryption. + :type decrypt_key: :class:`keyczar.keys.AESKey` + + :param crypto: Crypto text to be decrypted. + :type crypto: ``str`` + + :rtype: ``str`` + """ + from keyczar.keys import AesKey as KeyczarAesKey # pylint: disable=import-error + from keyczar.keys import HmacKey as KeyczarHmacKey # pylint: disable=import-error + from keyczar.keyinfo import GetMode # pylint: disable=import-error + + decrypt_key = KeyczarAesKey( + decrypt_key.aes_key_string, + KeyczarHmacKey(decrypt_key.hmac_key_string, decrypt_key.hmac_key_size), + decrypt_key.size, + GetMode(decrypt_key.mode), + ) + + return decrypt_key.Decrypt(binascii.unhexlify(ciphertext)) + + +def pkcs5_pad(data): + """ + Pad data using PKCS5 + """ + pad = KEYCZAR_AES_BLOCK_SIZE - len(data) % KEYCZAR_AES_BLOCK_SIZE + data = data + pad * chr(pad).encode("utf-8") + return data + + +def pkcs5_unpad(data): + """ + Unpad data padded using PKCS5. + """ + if isinstance(data, six.binary_type): + # Make sure we are operating with a string type + data = data.decode("utf-8") + + pad = ord(data[-1]) + data = data[:-pad] + return data + + +def Base64WSEncode(s): + """ + Return Base64 web safe encoding of s. Suppress padding characters (=). + + Uses URL-safe alphabet: - replaces +, _ replaces /. Will convert s of type + unicode to string type first. + + @param s: string to encode as Base64 + @type s: string + + @return: Base64 representation of s. + @rtype: string + + NOTE: Taken from keyczar (Apache 2.0 license) + """ + if isinstance(s, six.text_type): + # Make sure input string is always converted to bytes (if not already) + s = s.encode("utf-8") + + return base64.urlsafe_b64encode(s).decode("utf-8").replace("=", "") + + +def Base64WSDecode(s): + """ + Return decoded version of given Base64 string. Ignore whitespace. + + Uses URL-safe alphabet: - replaces +, _ replaces /. Will convert s of type + unicode to string type first. + + @param s: Base64 string to decode + @type s: string + + @return: original string that was encoded as Base64 + @rtype: string + + @raise Base64DecodingError: If length of string (ignoring whitespace) is one + more than a multiple of four. + + NOTE: Taken from keyczar (Apache 2.0 license) + """ + s = "".join(s.splitlines()) + s = str(s.replace(" ", "")) # kill whitespace, make string (not unicode) + d = len(s) % 4 + + if d == 1: + raise ValueError("Base64 decoding errors") + elif d == 2: + s += "==" + elif d == 3: + s += "=" + + try: + return base64.urlsafe_b64decode(s) + except TypeError as e: + # Decoding raises TypeError if s contains invalid characters. + raise ValueError("Base64 decoding error: %s" % (six.text_type(e))) diff --git a/st2client/st2client/utils/httpclient.py b/st2client/st2client/utils/httpclient.py index ec7bebdf64..2759ba13b6 100644 --- a/st2client/st2client/utils/httpclient.py +++ b/st2client/st2client/utils/httpclient.py @@ -165,7 +165,8 @@ def delete(self, url, **kwargs): return response def _response_hook(self, response): - if self.debug: + # in case we're in testing, FakeResponse does not have a request parameter :/ + if self.debug and hasattr(response, "request"): # Log cURL request line curl_line = self._get_curl_line_for_request(request=response.request) print("# -------- begin %d request ----------" % id(self)) diff --git a/st2client/st2client/utils/jsonify.py b/st2client/st2client/utils/jsonify.py new file mode 100644 index 0000000000..bedadf70cf --- /dev/null +++ b/st2client/st2client/utils/jsonify.py @@ -0,0 +1,195 @@ +# Copyright 2020 The StackStorm Authors. +# Copyright 2019 Extreme Networks, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +import logging + +LOG = logging.getLogger(__name__) + +try: + import simplejson as json + from simplejson import JSONEncoder +except ImportError: + import json + from json import JSONEncoder + +import six +import orjson + + +__all__ = [ + "json_encode", + "json_decode", + "json_loads", + "try_loads", + "get_json_type_for_python_value", +] + +# Which json library to use for data serialization and deserialization. +# We only expose this option so we can exercise code paths with different libraries inside the +# tests for compatibility reasons +DEFAULT_JSON_LIBRARY = "orjson" + + +class GenericJSON(JSONEncoder): + def default(self, obj): # pylint: disable=method-hidden + if hasattr(obj, "__json__") and six.callable(obj.__json__): + return obj.__json__() + else: + return JSONEncoder.default(self, obj) + + +def default(obj): + if hasattr(obj, "__json__") and six.callable(obj.__json__): + return obj.__json__() + elif isinstance(obj, bytes): + # TODO: We should update the code which passes bytes to pass unicode to avoid this + # conversion here + return obj.decode("utf-8") + raise TypeError + + +def json_encode_native_json(obj, indent=4, sort_keys=False): + if not indent: + separators = (",", ":") + else: + separators = None + return json.dumps( + obj, cls=GenericJSON, indent=indent, separators=separators, sort_keys=sort_keys + ) + + +def json_encode_orjson(obj, indent=None, sort_keys=False): + option = None + + if indent: + # NOTE: We don't use indent by default since it's quite a bit slower + option = orjson.OPT_INDENT_2 + + if sort_keys: + option = option | orjson.OPT_SORT_KEYS if option else orjson.OPT_SORT_KEYS + + if option: + return orjson.dumps(obj, default=default, option=option).decode("utf-8") + + return orjson.dumps(obj, default=default).decode("utf-8") + + +def json_decode_native_json(data): + return json.loads(data) + + +def json_decode_orjson(data): + return orjson.loads(data) + + +def json_encode(obj, indent=None, sort_keys=False): + """ + Wrapper function for encoding the provided object. + + This function automatically select appropriate JSON library based on the configuration value. + + This function should be used everywhere in the code base where json.dumps() behavior is desired. + """ + json_library = DEFAULT_JSON_LIBRARY + + if json_library == "json": + return json_encode_native_json(obj=obj, indent=indent, sort_keys=sort_keys) + elif json_library == "orjson": + return json_encode_orjson(obj=obj, indent=indent, sort_keys=sort_keys) + else: + raise ValueError("Unsupported json_library: %s" % (json_library)) + + +def json_decode(data): + """ + Wrapper function for decoding the provided JSON string. + + This function automatically select appropriate JSON library based on the configuration value. + + This function should be used everywhere in the code base where json.loads() behavior is desired. + """ + json_library = DEFAULT_JSON_LIBRARY + + if json_library == "json": + return json_decode_native_json(data=data) + elif json_library == "orjson": + return json_decode_orjson(data=data) + else: + raise ValueError("Unsupported json_library: %s" % (json_library)) + + +def load_file(path): + with open(path, "r") as fd: + return json.load(fd) + + +def json_loads(obj, keys=None): + """ + Given an object, this method tries to json.loads() the value of each of the keys. If json.loads + fails, the original value stays in the object. + + :param obj: Original object whose values should be converted to json. + :type obj: ``dict`` + + :param keys: Optional List of keys whose values should be transformed. + :type keys: ``list`` + + :rtype ``dict`` or ``None`` + """ + if not obj: + return None + + if not keys: + keys = list(obj.keys()) + + for key in keys: + try: + obj[key] = json_decode(obj[key]) + except Exception: + # NOTE: This exception is not fatal so we intentionally don't log anything. + # Method behaves in "best effort" manner and dictionary value not being JSON + # string is perfectly valid (and common) scenario so we should not log anything + pass + return obj + + +def try_loads(s): + try: + return json_decode(s) if s and isinstance(s, six.string_types) else s + except: + return s + + +def get_json_type_for_python_value(value): + """ + Return JSON type string for the provided Python value. + + :rtype: ``str`` + """ + if isinstance(value, six.text_type): + return "string" + elif isinstance(value, (int, float)): + return "number" + elif isinstance(value, dict): + return "object" + elif isinstance(value, (list, tuple)): + return "array" + elif isinstance(value, bool): + return "boolean" + elif value is None: + return "null" + else: + return "unknown" diff --git a/st2client/st2client/utils/sso_interceptor.py b/st2client/st2client/utils/sso_interceptor.py new file mode 100644 index 0000000000..935a43394e --- /dev/null +++ b/st2client/st2client/utils/sso_interceptor.py @@ -0,0 +1,184 @@ +# Copyright 2020 The StackStorm Authors. +# Copyright 2019 Extreme Networks, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import json +import logging +from threading import Thread +import time +from urllib.parse import urlparse, parse_qs +import uuid +from http.server import BaseHTTPRequestHandler, HTTPServer + +from st2client.utils.crypto import symmetric_decrypt + +LOG = logging.getLogger(__name__) + + +# Implements a local HTTP server used to intercept calls from/to SSO endpoints :) +# via callback URLs +class SSOInterceptorProxy: + + thread = None + server = None + # Identifier to be used to access the SSO proxy (e.g. localhost:31283/) + url_id = uuid.uuid4() + # where should the proxy redirect to upon hitting it? + sso_url = None + # key that is used to decrypt the response + key = None + # token object to receive the token once it's avaiable! + token = None + + def __init__(self, key, sso_port): + + self.server = HTTPServer(("localhost", sso_port), createSSOProxyHandler(self)) + self.key = key + + LOG.debug( + "Initialized SSO interceptor proxy at port %d and url id %s, SSO URL is still pending", + self.server.server_port, + self.url_id, + ) + + self.thread = Thread(target=self.server.serve_forever) + self.thread.setDaemon(True) + self.thread.start() + + def set_sso_url(self, sso_url): + self.sso_url = sso_url + LOG.debug("SSO URL set to [%s]", sso_url) + + def get_proxy_url(self): + return "http://localhost:%d/%s" % (self.server.server_port, self.url_id) + + def get_callback_url(self): + return "http://localhost:%d/callback" % (self.server.server_port) + + def callback_received(self, token): + LOG.debug("Callback received and intercepted, token is provided :)") + self.token = token + + def get_token(self, timeout=90): + LOG.debug( + "Waiting for token to be received from SSO flow.. will timeout after [%s]s", + timeout, + ) + timeout_at = time.time() + timeout + while time.time() < timeout_at: + if self.token is not None: + return self.token + time.sleep(0.5) + + raise TimeoutError( + "Token was not received from SSO flow before the timeout of %ss" % timeout + ) + + +def createSSOProxyHandler(interceptor: SSOInterceptorProxy): + class SSOProxyServer(BaseHTTPRequestHandler): + def do_GET(self): + + o = urlparse(self.path) + qs = parse_qs(o.query) + + try: + + if o.path == "/callback": + self._handle_callback(qs.get("response", [None])[0]) + elif o.path == "/success": + self._handle_success() + elif o.path == "/%s" % interceptor.url_id: + self._handle_sso_login() + else: + self._handle_unexpected_request() + + except ValueError as e: + self.send_error(400, explain="Invalid parameter: %s" % str(e)) + except Exception as e: + LOG.debug("Unexpected internal server error! %e", e) + self.send_error(500, explain="Unexpected error!" % str(e)) + + return True + + # This request is not expected by the sso proxy + def _handle_unexpected_request(self): + self.send_error(404, explain="The selected URL does not exist!") + self.end_headers() + + # This request is to redirect the user to the proper sso place + # -- can only be achieve with the proper key :) + def _handle_sso_login(self): + LOG.debug("Intercepting SSO begin flow from the user") + self.send_response(307) + self.send_header("Location", interceptor.sso_url) + self.end_headers() + + # This request should have all the callback data we are expecting + # -- this means an encrypted key to be decrypted and used by the CLI :) + def _handle_callback(self, response): + LOG.debug("Intercepting SSO callback response!") + + if response is None: + raise ValueError( + "Expected 'response' field with encrypted key in callback!" + ) + + token = None + try: + token = symmetric_decrypt(interceptor.key, response.encode("utf-8")) + token_json = json.loads(token) + LOG.debug( + "Successful SSO login for user %s, redirecting to successful page!", + token_json.get("user", None), + ) + except: + LOG.debug("Could not understand the SSO callback response!") + raise ValueError( + "Could not understand the incoming SSO callback response" + ) + + interceptor.callback_received(token) + self.send_response(302) + self.send_header("Location", "/success") + self.end_headers() + + # self.wfile.close() + + def _handle_success(self): + self.send_response(200) + self.end_headers() + self.wfile.write( + bytes( + """ + SSO Login Successful + +
+
Successfully logged into StackStorm using SSO!
+
Please check your terminal
+
You may now close this page
+
+ + """, + "utf-8", + ) + ) + + def log_message(self, format, *args): + LOG.debug("%s " + format, "SSO Proxy: ", *args) + return + + return SSOProxyServer diff --git a/st2client/tests/unit/test_auth.py b/st2client/tests/unit/test_auth.py index e59b31dfaf..bb1d951214 100644 --- a/st2client/tests/unit/test_auth.py +++ b/st2client/tests/unit/test_auth.py @@ -15,6 +15,8 @@ from __future__ import absolute_import import os +import re +from time import sleep, time import uuid import json import mock @@ -22,6 +24,8 @@ import requests import argparse import logging +from threading import Thread +from datetime import datetime, timedelta import six @@ -29,12 +33,16 @@ from st2client import shell from st2client.models.core import add_auth_token_to_kwargs_from_env from st2client.commands.resource import add_auth_token_to_kwargs_from_cli +from st2client.utils.crypto import ( + AESKey, + read_crypto_key_from_dict, + symmetric_encrypt, +) from st2client.utils.httpclient import ( add_auth_token_to_headers, add_json_content_type_to_headers, ) - LOG = logging.getLogger(__name__) if six.PY3: @@ -165,6 +173,159 @@ def runTest(self): ) +class TestLoginSSO(TestLoginBase): + + ORIGINAL_POST_FN = requests.post + + CONFIG_FILE_NAME = "logintest.cfg" + + LOGIN_REQUEST_MOCK_KEY = read_crypto_key_from_dict( + { + "hmacKey": { + "hmacKeyString": "-qdRklvhm4xvzIfaL6Z2nmQ-2N-c4IUtNa1_BowCVfg", + "size": 256, + }, + "aesKeyString": "0UyXFjBTQ9PMyHZ0mqrvuqCSzesuFup1d6m-4Vi3vdo", + "mode": "CBC", + "size": 256, + } + ) + + TOKEN = { + "user": "stanley", + "token": "44583f15945b4095afbf57058535ca64", + "expiry": "2017-02-12T00:53:09.632783Z", + "id": "589e607532ed3535707f10eb", + "metadata": {}, + } + + ENCRYPTED_TOKEN = symmetric_encrypt( + LOGIN_REQUEST_MOCK_KEY, json.dumps(TOKEN) + ).decode("utf-8") + + LOGIN_REQUEST_RESPONSE = { + # This is just a placeholder name, it's all mocked :) + "sso_url": "http://keycloak/realms/StackStorm/protocol/saml?SAMLRequest=fZFRS8MwFIX%2FSsl7TJPV1Ya1MB3iYOJYqw%2B%2BSJpFF2yTmXsr%2Bu%2FNplNU2OM53HNzvtyJg1ROB9y4lXkZDGDy1ncOZLRLMgQnvQIbpeoNSNSynl4vpDhJ5TZ49Np35DvAjwcUgAlovSPJfFYSu34QOs9yk2f0LB8rmqm2oAUfjempUCLX3Ohx25LkzgSIqZLEJTEKMJi5A1QOo5UKQdOC8qLhuRRCZuKeJLOIYZ3CfWqDuJWMdV6rbuMB5SjlnAWjuh5YjUo%2F1%2BhDzw48DFQfoZZf8ty6tXVPx9HazyGQV02zpMubuiHJ9IB74R0MvQm1Ca9Wm9vV4n8ppuIFGIBn0ejaWIpUk%2Fijco8bksvYUOHxEjvHrunjflQahxbfSfX3pQn7WVvtxO%2FrVx8%3D&RelayState=%7B%22referer%22%3A+%22http%3A%2F%2Flocalhost%3A34000%2Fcallback%22%7D", + "expiry": (datetime.now() + timedelta(hours=3)).strftime( + "%Y-%m-%dT%H:%M:%S.%f" + )[:-3] + + "000+00:00", + } + + @mock.patch.object(AESKey, "generate", return_value=LOGIN_REQUEST_MOCK_KEY) + @mock.patch( + "requests.post", + return_value=base.FakeResponse(json.dumps(LOGIN_REQUEST_RESPONSE), 200, "OK"), + ) + def runTest(self, mock_aeskey_generate, mock_post): + """Test 'st2 login --sso' functionality""" + + expected_username = self.TOKEN["user"] + args = [ + "--config", + self.CONFIG_FILE, + "login", + "--sso", + "--no-sso-browser", + "--sso-port", + "34000", + ] + + def handle_sso_flow(): + # Waiting for SSO link on the CLI + LOG.debug("Waiting for SSO link") + match = None + timeout_at = time() + 5 + while not match and timeout_at > time(): + sleep(1) + self.stdout.seek(0) + buffer = self.stdout.read() + LOG.debug("STDOUT buffer has: %s", buffer) + match = re.search(r"http://localhost:34000/\S+", buffer, re.MULTILINE) + self.assertIsNotNone(match) + + # Hitting the localhost login url + login_url = match[0] + LOG.debug("GETting SSO login to %s", login_url) + response = requests.get(login_url, allow_redirects=False) + self.assertEquals(response.status_code, 307) + self.assertEquals( + response.headers["Location"], self.LOGIN_REQUEST_RESPONSE["sso_url"] + ) + + # Ignoring IDP flow and just hittin callback with proper response :) + LOG.debug("Calling back to local server") + response = requests.get( + "http://localhost:34000/callback", + params={"response": self.ENCRYPTED_TOKEN}, + allow_redirects=False, + ) + self.assertEquals(response.status_code, 302) + self.assertEquals(response.headers["Location"], "/success") + LOG.debug("Finished SSO flow") + + def run_shell(): + self.shell.run(args) + + shellThread = Thread(target=run_shell) + shellThread.start() + + handle_sso_flow() + + shellThread.join() + + with open(self.CONFIG_FILE, "r") as config_file: + for line in config_file.readlines(): + print(line) + # Make sure certain values are not present + self.assertNotIn("password", line) + self.assertNotIn("olduser", line) + + # Make sure configured username is what we expect + if "username" in line: + self.assertEqual(line.split(" ")[2][:-1], expected_username) + + # validate token was created + self.assertTrue( + os.path.isfile("%stoken-%s" % (self.DOTST2_PATH, expected_username)) + ) + + +class TestLoginWithMissingUsername(TestLoginBase): + + CONFIG_FILE_NAME = "logintest.cfg" + + TOKEN = { + "user": "st2admin", + "token": "44583f15945b4095afbf57058535ca64", + "expiry": "2017-02-12T00:53:09.632783Z", + "id": "589e607532ed3535707f10eb", + "metadata": {}, + } + + @mock.patch.object( + requests, + "post", + mock.MagicMock(return_value=base.FakeResponse(json.dumps(TOKEN), 200, "OK")), + ) + def runTest(self): + """Test 'st2 login' functionality missing the username and should fail""" + + expected_username = self.TOKEN["user"] # noqa + args = [ + "--config", + self.CONFIG_FILE, + "login", + "--password", + "Password1!", + ] + + self.shell.run(args) + self.assertIn( + "Username expected when not using SSO login", self.stdout.getvalue() + ) + + class TestLoginIntPwdAndConfig(TestLoginBase): CONFIG_FILE_NAME = "logintest.cfg" diff --git a/st2client/tests/unit/test_shell.py b/st2client/tests/unit/test_shell.py index 5eb27714ca..425347db97 100644 --- a/st2client/tests/unit/test_shell.py +++ b/st2client/tests/unit/test_shell.py @@ -37,6 +37,7 @@ from st2common.models.db.auth import TokenDB from tests import base + LOG = logging.getLogger(__name__) BASE_DIR = os.path.dirname(os.path.abspath(__file__)) diff --git a/st2common/st2common/exceptions/auth.py b/st2common/st2common/exceptions/auth.py index 5eab1915f5..15a04a5680 100644 --- a/st2common/st2common/exceptions/auth.py +++ b/st2common/st2common/exceptions/auth.py @@ -31,6 +31,7 @@ "AmbiguousUserError", "NotServiceUserError", "SSOVerificationError", + "SSORequestNotFoundError", ] @@ -38,6 +39,10 @@ class TokenNotProvidedError(StackStormBaseException): pass +class SSORequestNotFoundError(StackStormBaseException): + pass + + class TokenNotFoundError(StackStormDBObjectNotFoundError): pass diff --git a/st2common/st2common/models/db/auth.py b/st2common/st2common/models/db/auth.py index 2531ecb11a..ccb8558cea 100644 --- a/st2common/st2common/models/db/auth.py +++ b/st2common/st2common/models/db/auth.py @@ -16,6 +16,7 @@ from __future__ import absolute_import import copy +from enum import Enum import mongoengine as me from st2common.constants.secrets import MASKED_ATTRIBUTE_VALUE @@ -25,7 +26,7 @@ from st2common.rbac.backends import get_rbac_backend from st2common.util import date as date_utils -__all__ = ["UserDB", "TokenDB", "ApiKeyDB"] +__all__ = ["UserDB", "TokenDB", "ApiKeyDB", "SSORequestDB"] class UserDB(stormbase.StormFoundationDB): @@ -85,6 +86,29 @@ class TokenDB(stormbase.StormFoundationDB): service = me.BooleanField(required=True, default=False) +class SSORequestDB(stormbase.StormFoundationDB): + class Type(Enum): + CLI = "cli" + WEB = "web" + + """ + An entity representing a SSO request. + + Attribute: + request_id: Reference to the SSO request unique ID + expiry: Time at which this request expires. + type: What type of SSO request is this? web/cli + + -- cli -- + key: Symmetric key used to encrypt/decrypt contents from/to the CLI. + """ + + request_id = me.StringField(required=True) + key = me.StringField(required=False, unique=False) + expiry = me.DateTimeField(required=True) + type = me.EnumField(Type, required=True) + + class ApiKeyDB(stormbase.StormFoundationDB, stormbase.UIDFieldMixin): """ An entity representing an API key object. @@ -127,4 +151,4 @@ def mask_secrets(self, value): return result -MODELS = [UserDB, TokenDB, ApiKeyDB] +MODELS = [UserDB, TokenDB, ApiKeyDB, SSORequestDB] diff --git a/st2common/st2common/openapi.yaml b/st2common/st2common/openapi.yaml index e86e42727d..de2c693a27 100644 --- a/st2common/st2common/openapi.yaml +++ b/st2common/st2common/openapi.yaml @@ -4526,10 +4526,11 @@ paths: schema: $ref: '#/definitions/Error' security: [] - /auth/v1/sso/request: + + /auth/v1/sso/request/web: get: - operationId: st2auth.controllers.v1.sso:sso_request_controller.get - description: Redirects to the SSO Idp login page. + operationId: st2auth.controllers.v1.sso:sso_request_controller.get_web + description: Redirects to the SSO Idp login page from a user that's using the browser. parameters: - name: referer in: header @@ -4539,6 +4540,31 @@ paths: '307': description: Temporary redirect security: [] + + /auth/v1/sso/request/cli: + post: + operationId: st2auth.controllers.v1.sso:sso_request_controller.post_cli + description: Issues an encrypted SSO login request for a CLI + parameters: + - name: response + in: body + description: SSO request with callback and key encryption + schema: + type: object + required: + - key + - callback_url + properties: + key: + type: string + description: The symmetric key to be used to encrypt contents of callback + callback_url: + type: string + description: What URL to be called back once the response from SSO is received + responses: + '200': + description: SSO request valid + security: [] /auth/v1/sso/callback: post: operationId: st2auth.controllers.v1.sso:idp_callback_controller.post @@ -4552,6 +4578,8 @@ paths: responses: '200': description: SSO response valid + '302': + description: SSO response valid and callback URL returned '401': description: Invalid or missing credentials has been provided schema: diff --git a/st2common/st2common/openapi.yaml.j2 b/st2common/st2common/openapi.yaml.j2 index f053f0f3d0..9c2177bc41 100644 --- a/st2common/st2common/openapi.yaml.j2 +++ b/st2common/st2common/openapi.yaml.j2 @@ -4522,10 +4522,11 @@ paths: schema: $ref: '#/definitions/Error' security: [] - /auth/v1/sso/request: + + /auth/v1/sso/request/web: get: - operationId: st2auth.controllers.v1.sso:sso_request_controller.get - description: Redirects to the SSO Idp login page. + operationId: st2auth.controllers.v1.sso:sso_request_controller.get_web + description: Redirects to the SSO Idp login page from a user that's using the browser. parameters: - name: referer in: header @@ -4535,6 +4536,31 @@ paths: '307': description: Temporary redirect security: [] + + /auth/v1/sso/request/cli: + post: + operationId: st2auth.controllers.v1.sso:sso_request_controller.post_cli + description: Issues an encrypted SSO login request for a CLI + parameters: + - name: response + in: body + description: SSO request with callback and key encryption + schema: + type: object + required: + - key + - callback_url + properties: + key: + type: string + description: The symmetric key to be used to encrypt contents of callback + callback_url: + type: string + description: What URL to be called back once the response from SSO is received + responses: + '200': + description: SSO request valid + security: [] /auth/v1/sso/callback: post: operationId: st2auth.controllers.v1.sso:idp_callback_controller.post @@ -4548,6 +4574,8 @@ paths: responses: '200': description: SSO response valid + '302': + description: SSO response valid and callback URL returned '401': description: Invalid or missing credentials has been provided schema: diff --git a/st2common/st2common/persistence/auth.py b/st2common/st2common/persistence/auth.py index a8fad7488f..78b168c5da 100644 --- a/st2common/st2common/persistence/auth.py +++ b/st2common/st2common/persistence/auth.py @@ -15,6 +15,7 @@ from __future__ import absolute_import from st2common.exceptions.auth import ( + SSORequestNotFoundError, TokenNotFoundError, ApiKeyNotFoundError, UserNotFoundError, @@ -22,7 +23,7 @@ NoNicknameOriginProvidedError, ) from st2common.models.db import MongoDBAccess -from st2common.models.db.auth import UserDB, TokenDB, ApiKeyDB +from st2common.models.db.auth import SSORequestDB, UserDB, TokenDB, ApiKeyDB from st2common.persistence.base import Access from st2common.util import hash as hash_utils @@ -59,6 +60,44 @@ def _get_by_object(cls, object): return cls.get_by_name(name) +class SSORequest(Access): + impl = MongoDBAccess(SSORequestDB) + + @classmethod + def _get_impl(cls): + return cls.impl + + @classmethod + def add_or_update(cls, model_object, publish=True, validate=True): + if not getattr(model_object, "request_id", None): + raise ValueError("SSO Request ID is not provided in the object.") + if not getattr(model_object, "type", None): + raise ValueError("SSO request type is not defined in the object") + if not getattr(model_object, "expiry", None): + raise ValueError("SSO request expiry is not provided in the object.") + return super(SSORequest, cls).add_or_update( + model_object, publish=publish, validate=validate + ) + + @classmethod + def get(cls, value): + result = cls.query(id=value).first() + + if not result: + raise SSORequestNotFoundError() + + return result + + @classmethod + def get_by_request_id(cls, value): + result = cls.query(request_id=value).first() + + if not result: + raise SSORequestNotFoundError() + + return result + + class Token(Access): impl = MongoDBAccess(TokenDB) diff --git a/st2common/st2common/services/access.py b/st2common/st2common/services/access.py index 9d88c39c42..73433ee849 100644 --- a/st2common/st2common/services/access.py +++ b/st2common/st2common/services/access.py @@ -21,16 +21,27 @@ from st2common.util import isotime from st2common.util import date as date_utils -from st2common.exceptions.auth import TokenNotFoundError, UserNotFoundError +from st2common.exceptions.auth import ( + TokenNotFoundError, + UserNotFoundError, +) from st2common.exceptions.auth import TTLTooLargeException -from st2common.models.db.auth import TokenDB, UserDB -from st2common.persistence.auth import Token, User +from st2common.models.db.auth import SSORequestDB, TokenDB, UserDB +from st2common.persistence.auth import SSORequest, Token, User from st2common import log as logging -__all__ = ["create_token", "delete_token"] +__all__ = [ + "create_token", + "delete_token", + "create_cli_sso_request", + "create_web_sso_request", + "get_sso_request_by_request_id", +] LOG = logging.getLogger(__name__) +DEFAULT_SSO_REQUEST_TTL = 120 + def create_token( username, ttl=None, metadata=None, add_missing_user=True, service=False @@ -105,3 +116,52 @@ def delete_token(token): pass except Exception: raise + + +def create_cli_sso_request(request_id, key, ttl=DEFAULT_SSO_REQUEST_TTL): + """ + :param request_id: ID of the SSO request that is being created (usually uuid format prepended by _) + :type request_id: ``str`` + + :param key: Symmetric key used to encrypt/decrypt the request between the CLI and the server + :type key: ``str`` + + :param ttl: SSO request TTL (in seconds). + :type ttl: ``int`` + """ + + return _create_sso_request(request_id, ttl, SSORequestDB.Type.CLI, key=key) + + +def create_web_sso_request(request_id, ttl=DEFAULT_SSO_REQUEST_TTL): + """ + :param request_id: ID of the SSO request that is being created (usually uuid format prepended by _) + :type request_id: ``str`` + + :param ttl: SSO request TTL (in seconds). + :type ttl: ``int`` + """ + + return _create_sso_request(request_id, ttl, SSORequestDB.Type.WEB) + + +def _create_sso_request(request_id, ttl, type, **kwargs) -> SSORequestDB: + + expiry = date_utils.get_datetime_utc_now() + datetime.timedelta(seconds=ttl) + + request = SSORequestDB(request_id=request_id, expiry=expiry, type=type, **kwargs) + SSORequest.add_or_update(request) + + expire_string = isotime.format(expiry, offset=False) + + LOG.audit( + 'Created SAML request with ID "%s" set to expire at "%s" of type "%s".' + % (request_id, expire_string, type) + ) + + return request + + +def get_sso_request_by_request_id(request_id) -> SSORequestDB: + request_db = SSORequest.get_by_request_id(request_id) + return request_db diff --git a/st2common/st2common/util/crypto.py b/st2common/st2common/util/crypto.py index 0aea24763c..e6be862101 100644 --- a/st2common/st2common/util/crypto.py +++ b/st2common/st2common/util/crypto.py @@ -184,16 +184,33 @@ def read_crypto_key(key_path): content = json_decode(content) + try: + return read_crypto_key_from_dict(content) + except KeyError as e: + msg = 'Invalid or malformed key file "%s": %s' % (key_path, six.text_type(e)) + raise KeyError(msg) + + +def read_crypto_key_from_dict(key_dict): + """ + Read crypto key from provided Keyczar JSON-format dict and return parsed AESKey object. + + :param key_dict: A dictionary with a key in Keyczar format (same keys as the JSON). + :type key_dict: ``dict`` + + :rtype: :class:`AESKey` + """ + try: aes_key = AESKey( - aes_key_string=content["aesKeyString"], - hmac_key_string=content["hmacKey"]["hmacKeyString"], - hmac_key_size=content["hmacKey"]["size"], - mode=content["mode"].upper(), - size=content["size"], + aes_key_string=key_dict["aesKeyString"], + hmac_key_string=key_dict["hmacKey"]["hmacKeyString"], + hmac_key_size=key_dict["hmacKey"]["size"], + mode=key_dict["mode"].upper(), + size=key_dict["size"], ) except KeyError as e: - msg = 'Invalid or malformed key file "%s": %s' % (key_path, six.text_type(e)) + msg = "Invalid or malformed AES key dictionary: %s" % (six.text_type(e)) raise KeyError(msg) return aes_key diff --git a/st2common/tests/unit/services/test_access.py b/st2common/tests/unit/services/test_access.py index 4f7d8169b4..7ca61b358b 100644 --- a/st2common/tests/unit/services/test_access.py +++ b/st2common/tests/unit/services/test_access.py @@ -18,6 +18,7 @@ import uuid from oslo_config import cfg +from st2common.models.db.auth import SSORequestDB from st2tests.base import DbTestCase from st2common.util import isotime from st2common.util import date as date_utils @@ -30,6 +31,8 @@ USERNAME = "manas" +SSO_REQUEST_ID = "a58fa0cd-61c8-4bd9-a2e7-a4497d6aca68" + class AccessServiceTest(DbTestCase): @classmethod @@ -106,3 +109,37 @@ def test_create_token_service_token_can_use_arbitrary_ttl(self): self.assertRaises( TTLTooLargeException, access.create_token, USERNAME, ttl=ttl, service=False ) + + def test_create_cli_sso_request(self): + request = access.create_cli_sso_request(SSO_REQUEST_ID, None, 20) + self.assertIsNotNone(request) + self.assertEqual(request.type, SSORequestDB.Type.CLI) + self.assertEqual(request.request_id, SSO_REQUEST_ID) + self.assertLessEqual( + abs( + request.expiry.timestamp() + - date_utils.get_datetime_utc_now().timestamp() + - 20 + ), + 2, + ) + + def test_create_web_sso_request(self): + request = access.create_web_sso_request(SSO_REQUEST_ID, 20) + self.assertIsNotNone(request) + self.assertEqual(request.type, SSORequestDB.Type.WEB) + self.assertEqual(request.request_id, SSO_REQUEST_ID) + self.assertLessEqual( + abs( + request.expiry.timestamp() + - date_utils.get_datetime_utc_now().timestamp() + - 20 + ), + 2, + ) + + def test_get_sso_request_by_id(self): + access.create_web_sso_request(SSO_REQUEST_ID, 20) + request = access.get_sso_request_by_request_id(SSO_REQUEST_ID) + self.assertIsNotNone(request) + self.assertEqual(request.request_id, SSO_REQUEST_ID) diff --git a/st2common/tests/unit/test_db_auth.py b/st2common/tests/unit/test_db_auth.py index b159580505..3cae8be5cb 100644 --- a/st2common/tests/unit/test_db_auth.py +++ b/st2common/tests/unit/test_db_auth.py @@ -14,14 +14,16 @@ # limitations under the License. from __future__ import absolute_import -from st2common.models.db.auth import UserDB +import datetime +from st2common.models.db.auth import SSORequestDB, UserDB from st2common.models.db.auth import TokenDB from st2common.models.db.auth import ApiKeyDB -from st2common.persistence.auth import User +from st2common.persistence.auth import SSORequest, User from st2common.persistence.auth import Token from st2common.persistence.auth import ApiKey -from st2common.util.date import get_datetime_utc_now +from st2common.util.date import add_utc_tz, get_datetime_utc_now from st2tests import DbTestCase +from mongoengine.errors import ValidationError from tests.unit.base import BaseDBModelCRUDTestCase @@ -58,3 +60,61 @@ class ApiKeyDBModelCRUDTestCase(BaseDBModelCRUDTestCase, DbTestCase): persistance_class = ApiKey model_class_kwargs = {"user": "pony", "key_hash": "token-token-token-token"} update_attribute_name = "user" + + +class SSORequestDBModelCRUDTestCase(BaseDBModelCRUDTestCase, DbTestCase): + model_class = SSORequestDB + persistance_class = SSORequest + model_class_kwargs = { + "request_id": "48144c2b-7969-4708-ba1d-96fd7d05393f", + "expiry": add_utc_tz( + datetime.datetime.strptime("2050-01-05T10:00:00", "%Y-%m-%dT%H:%M:%S") + ), + "type": SSORequestDB.Type.CLI, + } + update_attribute_name = "request_id" + + def _save_model(self, **kwargs): + model_db = self.model_class(**kwargs) + self.persistance_class.add_or_update(model_db) + + def test_missing_parameters(self): + + self.assertRaises( + ValueError, + self._save_model, + **{ + "request_id": self.model_class_kwargs["request_id"], + "expiry": self.model_class_kwargs["expiry"], + }, + ) + + self.assertRaises( + ValueError, + self._save_model, + **{ + "request_id": self.model_class_kwargs["request_id"], + "type": self.model_class_kwargs["type"], + }, + ) + + self.assertRaises( + ValueError, + self._save_model, + **{ + "type": self.model_class_kwargs["type"], + "expiry": self.model_class_kwargs["expiry"], + }, + ) + + def test_invalid_parameters(self): + + self.assertRaises( + ValidationError, + self._save_model, + **{ + "type": "invalid", + "expiry": self.model_class_kwargs["expiry"], + "request_id": self.model_class_kwargs["request_id"], + }, + )