Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ dependencies = [
"python-multipart>=0.0.27",
"jinja2>=3.1",
"py-key-value-aio[disk]",
"aiosqlite>=0.20",
"asyncpg>=0.30",
"pyjwt>=2.12.1",
"argon2-cffi>=25.1.0",
"base58>=2.1.1",
Expand All @@ -51,6 +53,7 @@ dev = [
"ruff>=0.9",
"ty",
"httpx>=0.28.1",
"pre-commit>=4.6.0",
]

[project.scripts]
Expand Down
31 changes: 14 additions & 17 deletions src/authsome/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,23 +13,18 @@
from authsome.auth.sessions import AuthSessionStore
from authsome.errors import AuthsomeError
from authsome.identity.proof import ReplayCache
from authsome.paths import get_server_log_path
from authsome.server.analytics import init_posthog, shutdown_posthog
from authsome.server.dependencies import (
create_app_store,
create_hosted_account_service,
create_identity_bootstrap_service,
create_identity_claim_registry,
create_ownership_resolver,
create_principal_vault_binding_registry,
create_store,
create_vault,
create_vault_registry,
get_identity_registry_path,
get_server_base_url,
get_server_log_path,
load_server_config,
load_ui_session_signing_secret,
)
from authsome.server.registries import IdentityRegistrationError, IdentityRegistry
from authsome.server.routes.auth import router as auth_router
from authsome.server.routes.connections import router as connections_router
from authsome.server.routes.health import router as health_router
Expand All @@ -38,33 +33,35 @@
from authsome.server.routes.proxy import router as proxy_router
from authsome.server.routes.ui import UiAuthRequiredError
from authsome.server.routes.ui import router as ui_router
from authsome.server.store.repositories import IdentityRegistrationError
from authsome.server.ui_sessions import UiSessionStore


@asynccontextmanager
async def lifespan(app: FastAPI):
"""Manage daemon lifecycle."""
app.state.store = await create_app_store()
app.state.server_config = load_server_config(app.state.store.home)
app.state.store = await create_store()
app.state.server_config = await load_server_config(app.state.store)
audit.setup(get_server_log_path(app.state.store.home))
app.state.vault = await create_vault(app.state.store)
app.state.vault = await create_vault(app.state.store.home)
app.state.auth_sessions = AuthSessionStore()
app.state.ui_sessions = UiSessionStore(load_ui_session_signing_secret(app.state.store.home))
app.state.proof_replay_cache = ReplayCache()
app.state.identity_registry = IdentityRegistry(get_identity_registry_path(app.state.store.home))
app.state.vault_registry = create_vault_registry(app.state.store.home)
app.state.identity_claim_registry = create_identity_claim_registry(app.state.store.home)
app.state.principal_vault_binding_registry = create_principal_vault_binding_registry(app.state.store.home)
app.state.hosted_account_service = create_hosted_account_service(app.state.store.home)
app.state.identity_registry = app.state.store.identity_registry
app.state.vault_registry = app.state.store.vaults
app.state.identity_claim_registry = app.state.store.identity_claims
app.state.principal_vault_binding_registry = app.state.store.principal_vault_bindings
app.state.provider_definition_repository = app.state.store.provider_definitions
app.state.hosted_account_service = create_hosted_account_service(app.state.store)
app.state.server_base_url = get_server_base_url()
init_posthog()
app.state.identity_bootstrap = create_identity_bootstrap_service(
app.state.identity_registry,
app.state.ui_sessions,
home=app.state.store.home,
store=app.state.store,
server_base_url=app.state.server_base_url,
)
app.state.ownership_resolver = create_ownership_resolver(app.state.store.home)
app.state.ownership_resolver = create_ownership_resolver(app.state.store)
app.state.ownership_cache = {}
yield
shutdown_posthog()
Expand Down
39 changes: 13 additions & 26 deletions src/authsome/server/credential_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,12 @@
IdentityNotFoundError,
InvalidProviderSchemaError,
OperationNotAllowedError,
ProviderAlreadyRegisteredError,
ProviderNotFoundError,
RefreshFailedError,
TokenExpiredError,
UnsupportedFlowError,
)
from authsome.server.store.repositories import ProviderDefinitionRepository
from authsome.utils import build_store_key, format_duration, is_filesystem_safe, parse_store_key, utc_now
from authsome.vault import Vault

Expand Down Expand Up @@ -86,6 +86,7 @@ class AuthService:
def __init__(
self,
vault: Vault,
provider_definitions: ProviderDefinitionRepository,
identity: str | None = None,
principal_id: str | None = None,
vault_id: str | None = None,
Expand All @@ -96,6 +97,7 @@ def __init__(
self._principal_id = principal_id
self._vault_id = vault_id
self._deployment_mode = "hosted" if deployment_mode == "hosted" else "local"
self._provider_definitions = provider_definitions
self._bundled: dict[str, ProviderDefinition] = self._load_bundled_providers()

@property
Expand Down Expand Up @@ -150,15 +152,8 @@ def _load_bundled_providers() -> dict[str, ProviderDefinition]:
return bundled

async def _load_custom_providers(self) -> dict[str, ProviderDefinition]:
providers: dict[str, ProviderDefinition] = {}
try:
for name in await self._vault.list(collection="providers"):
raw = await self._vault.get(name, collection="providers")
if raw:
providers[name] = ProviderDefinition.model_validate_json(raw)
except Exception as exc:
logger.warning("Could not load custom providers: {}", exc)
return providers
providers = await self._provider_definitions.list()
return {provider.name: provider for provider in providers}

async def list_providers(self) -> list[ProviderDefinition]:
providers = {**self._bundled, **(await self._load_custom_providers())}
Expand All @@ -171,17 +166,16 @@ async def list_providers_by_source(self) -> dict[str, list[ProviderDefinition]]:
return {"bundled": bundled_list, "custom": custom_list}

async def get_provider(self, provider: str) -> ProviderDefinition:
raw = await self._vault.get(provider, collection="providers")
if raw:
return ProviderDefinition.model_validate_json(raw)
custom = await self._provider_definitions.get(provider)
if custom is not None:
return custom
if provider in self._bundled:
return self._bundled[provider]
raise ProviderNotFoundError(provider)

async def is_local_provider(self, provider: str) -> bool:
"""Check if a provider is a custom/local provider."""
val = await self._vault.get(provider, collection="providers")
return val is not None
return await self._provider_definitions.get(provider) is not None

async def resolve_credentials(self, **kwargs: Any) -> dict[str, Any]:
"""Resolve credentials for a provider/connection pair."""
Expand All @@ -200,20 +194,12 @@ async def resolve_credentials(self, **kwargs: Any) -> dict[str, Any]:
async def register_provider(self, definition: ProviderDefinition, *, force: bool = False) -> None:
self._ensure_local_provider_admin_operation_allowed("register", definition.name)
self._validate_provider(definition)
has_custom = (await self._vault.get(definition.name, collection="providers")) is not None
if force or not has_custom:
await self._vault.put(
definition.name,
definition.model_dump_json(indent=2, exclude_none=True),
collection="providers",
)
else:
raise ProviderAlreadyRegisteredError(definition.name)
await self._provider_definitions.save(definition, force=force)
logger.info("Registered provider: {}", definition.name)

async def remove_provider(self, name: str) -> bool:
"""Remove a custom provider. Returns True if removed."""
return await self._vault.delete(name, collection="providers")
return await self._provider_definitions.delete(name)

def _ensure_local_provider_admin_operation_allowed(self, operation: str, provider: str) -> None:
if is_admin_principal(self._principal_id):
Expand Down Expand Up @@ -777,6 +763,7 @@ async def revoke(self, provider: str, vault_ids: list[str] | None = None) -> Non
principal_id=self._principal_id,
vault_id=vault_id,
deployment_mode=self._deployment_mode,
provider_definitions=self._provider_definitions,
)
meta_key = build_store_key(vault=vault_id, provider=provider, record_type="metadata")
existing_json = await self._vault.get(meta_key, collection=vault_service._coll)
Expand All @@ -796,7 +783,7 @@ async def remove(self, provider: str) -> None:
self._ensure_local_provider_admin_operation_allowed("remove", provider)
await self.revoke(provider)
if await self.is_local_provider(provider):
await self._vault.delete(provider, collection="providers")
await self._provider_definitions.delete(provider)
logger.info("Removed local provider definition: {}", provider)
else:
logger.info("Revoked bundled provider: {} (definition kept)", provider)
Expand Down
Loading
Loading