Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 128 additions & 27 deletions application_sdk/services/secretstore.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,21 @@
"""Unified secret store service for the application."""
"""Unified secret store service for the application.

Logic summary:

1. Fetch credential config from state store.

2. Determine mode: Multi-key if (credentialSource == 'direct' OR secret-path is defined), Single-key otherwise.

3. Fetch secrets accordingly: Multi-key uses secret_path if credentialSource == "agent" else credential_guid, Single-key fetches each field individually.

4. Merge & resolve secrets.
"""

import collections.abc
import copy
import json
import uuid
from enum import Enum
from typing import Any, Dict

from dapr.clients import DaprClient
Expand All @@ -23,8 +35,22 @@
logger = get_logger(__name__)


class CredentialSource(Enum):
"""Enumeration of credential source types."""

DIRECT = "direct"
AGENT = "agent"


class SecretMode(Enum):
"""Enumeration of secret retrieval modes."""

MULTI_KEY = "multi-key"
SINGLE_KEY = "single-key"


class SecretStore:
"""Unified secret store service for handling secret management."""
"""Unified secret store service for handling secret management across providers."""

@classmethod
async def get_credentials(cls, credential_guid: str) -> Dict[str, Any]:
Expand All @@ -33,6 +59,8 @@ async def get_credentials(cls, credential_guid: str) -> Dict[str, Any]:
This method retrieves credential configuration from the state store and resolves
any secret references by fetching actual values from the secret store.

Supports Multi-key mode (direct / has secret-path) and Single-key mode (no secret-path, non-direct).

Args:
credential_guid (str): The unique GUID of the credential configuration to resolve.

Expand All @@ -48,7 +76,6 @@ async def get_credentials(cls, credential_guid: str) -> Dict[str, Any]:
>>> creds = await SecretStore.get_credentials("db-cred-abc123")
>>> print(f"Connecting to {creds['host']}:{creds['port']}")
>>> # Password is automatically resolved from secret store

>>> # Handle resolution errors
>>> try:
... creds = await SecretStore.get_credentials("invalid-guid")
Expand All @@ -62,13 +89,42 @@ async def _get_credentials_async(credential_guid: str) -> Dict[str, Any]:
credential_guid, StateType.CREDENTIALS
)

# Fetch secret data from secret store
secret_key = credential_config.get("secret-path", credential_guid)
secret_data = SecretStore.get_secret(secret_key=secret_key)
credential_source_str = credential_config.get("credentialSource", "direct")
try:
credential_source = CredentialSource(credential_source_str)
except ValueError:
credential_source = CredentialSource.DIRECT
secret_path = credential_config.get("secret-path")

secret_data: Dict[str, Any] = {}

# Decide mode
if credential_source == CredentialSource.DIRECT or secret_path:
mode = SecretMode.MULTI_KEY
else:
mode = SecretMode.SINGLE_KEY

# Multi-key secret fetch (direct or has secret-path)
if mode == SecretMode.MULTI_KEY:
key_to_fetch = (
secret_path
if credential_source == CredentialSource.AGENT
else credential_guid
)
try:
logger.debug(f"Fetching multi-key secret from '{key_to_fetch}'")
secret_data = cls.get_secret(secret_key=key_to_fetch)
except Exception as e:
logger.warning(
f"Failed to fetch secret bundle '{key_to_fetch}': {e}"
)

# Single-key mode → per-field secret lookup
else:
secret_data = cls._fetch_single_key_secrets(credential_config)

# Resolve credentials
credential_source = credential_config.get("credentialSource", "direct")
if credential_source == "direct":
# Merge or resolve references
if credential_source == CredentialSource.DIRECT:
credential_config.update(secret_data)
return credential_config
else:
Expand All @@ -78,12 +134,45 @@ async def _get_credentials_async(credential_guid: str) -> Dict[str, Any]:
# Run async operations directly
return await _get_credentials_async(credential_guid)
except Exception as e:
logger.error(f"Error resolving credentials: {str(e)}")
logger.error(f"Error resolving credentials for {credential_guid}: {str(e)}")
raise CommonError(
CommonError.CREDENTIALS_RESOLUTION_ERROR,
f"Failed to resolve credentials: {str(e)}",
)

# Secret resolution helpers

@classmethod
def _fetch_single_key_secrets(
cls, credential_config: Dict[str, Any]
) -> Dict[str, Any]:
"""Fetch secrets in single-key mode by looking up each field individually.

Args:
credential_config: The credential configuration dictionary

Returns:
Dictionary containing collected secret values
"""
logger.debug("Single-key mode: fetching secrets per field")
collected = {}
for field, value in credential_config.items():
if not isinstance(value, str):
continue
try:
single_secret = cls.get_secret(value)
if single_secret:
for k, v in single_secret.items():
# Only filter out None and empty strings, not all falsy values.
# This preserves valid secret values like False, 0, 0.0 which are
# legitimate secret values that should not be excluded.
if v is None or v == "":
continue
collected[k] = v
except Exception as e:
logger.debug(f"Skipping '{field}' → '{value}' ({e})")
return collected

@classmethod
def resolve_credentials(
cls, credential_config: Dict[str, Any], secret_data: Dict[str, Any]
Expand All @@ -106,7 +195,6 @@ def resolve_credentials(
>>> secrets = {"db_password_key": "actual_secret_password"}
>>> resolved = SecretStore.resolve_credentials(config, secrets)
>>> print(resolved) # {"host": "db.example.com", "password": "actual_secret_password"}

>>> # Resolution with nested 'extra' fields
>>> config = {
... "host": "db.example.com",
Expand Down Expand Up @@ -143,7 +231,7 @@ def get_deployment_secret(cls) -> Dict[str, Any]:

Returns:
Dict[str, Any]: Deployment configuration data, or empty dict if
component is unavailable or fetch fails.
component is unavailable or fetch fails.

Examples:
>>> # Get deployment configuration
Expand All @@ -153,15 +241,14 @@ def get_deployment_secret(cls) -> Dict[str, Any]:
... print(f"Region: {config.get('region')}")
>>> else:
... print("No deployment configuration available")

>>> # Use in application initialization
>>> deployment_config = SecretStore.get_deployment_secret()
>>> if deployment_config.get('debug_mode'):
... logging.getLogger().setLevel(logging.DEBUG)
"""
if not is_component_registered(DEPLOYMENT_SECRET_STORE_NAME):
logger.warning(
f"Deployment secret store component '{DEPLOYMENT_SECRET_STORE_NAME}' is not registered"
f"Deployment secret store component '{DEPLOYMENT_SECRET_STORE_NAME}' not registered."
)
return {}

Expand All @@ -182,7 +269,7 @@ def get_secret(

Args:
secret_key (str): Key of the secret to fetch from the secret store.
component_name (str, optional): Name of the Dapr component to fetch from.
component_name (str): Name of the Dapr component to fetch from.
Defaults to SECRET_STORE_NAME.

Returns:
Expand All @@ -199,7 +286,6 @@ def get_secret(
>>> # Get database credentials
>>> db_secret = SecretStore.get_secret("database-credentials")
>>> print(f"Host: {db_secret.get('host')}")

>>> # Get from specific component
>>> api_secret = SecretStore.get_secret(
... "api-keys",
Expand All @@ -217,7 +303,7 @@ def get_secret(
return cls._process_secret_data(dapr_secret_object.secret)
except Exception as e:
logger.error(
f"Failed to fetch secret using component {component_name}: {str(e)}"
f"Failed to fetch secret using component '{component_name}': {str(e)}"
)
raise

Expand All @@ -235,17 +321,34 @@ def _process_secret_data(cls, secret_data: Any) -> Dict[str, Any]:
if isinstance(secret_data, collections.abc.Mapping):
secret_data = dict(secret_data)

# If the dict has a single key and its value is a JSON string, parse it
if len(secret_data) == 1 and isinstance(next(iter(secret_data.values())), str):
# Handle single-key secrets gracefully
if len(secret_data) == 1:
k, v = next(iter(secret_data.items()))
return cls._handle_single_key_secret(k, v)

return secret_data

# Utility helpers

@classmethod
def _handle_single_key_secret(cls, key: str, value: Any) -> Dict[str, Any]:
"""Handle single-key secret by attempting to parse JSON value.

Args:
key: The secret key.
value: The secret value (may be a JSON string).

Returns:
Dictionary with parsed JSON if value is valid JSON dict, otherwise {key: value}.
"""
if isinstance(value, str):
try:
parsed = json.loads(next(iter(secret_data.values())))
parsed = json.loads(value)
if isinstance(parsed, dict):
secret_data = parsed
except Exception as e:
logger.error(f"Failed to parse secret data: {e}")
return parsed
except Exception:
pass

return secret_data
return {key: value}

@classmethod
def apply_secret_values(
Expand Down Expand Up @@ -275,7 +378,6 @@ def apply_secret_values(
... "db_password": "secure_db_password"
... }
>>> resolved = SecretStore.apply_secret_values(source, secrets)

>>> # With nested extra fields
>>> source = {
... "host": "api.example.com",
Expand Down Expand Up @@ -328,7 +430,6 @@ async def save_secret(cls, config: Dict[str, Any]) -> str:
... }
>>> guid = await SecretStore.save_secret(config)
>>> print(f"Stored credentials with GUID: {guid}")

>>> # Later retrieve these credentials
>>> retrieved = await SecretStore.get_credentials(guid)
"""
Expand Down
Loading