diff --git a/AGENTS.md b/AGENTS.md index 208c1ed0..c76a559e 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -53,6 +53,19 @@ hatch test benchmark --benchmark-storage=file://benchmark/results just coverage ``` +### Running Examples + +Use `uv` to run examples from the `examples/` directory. Refer to the docstrings within each example file for specific commands. +Every example's docstring should include a command for running it with uvicorn. +```bash +# Run a basic query example +uv run examples/basic_query_example.py + +# Run an ASGI example with uvicorn +uv run --with "uvicorn[standard]" --with ariadne \ + uvicorn examples.basic_query_example:app --reload +``` + ## Code Style Requirements - **Python 3.10+** with type hints throughout @@ -96,4 +109,6 @@ Follow [Conventional Commits](https://www.conventionalcommits.org/): 2. Ensure test coverage meets the 90% minimum requirement 3. Format code with `just fmt` 4. Verify type hints with `just types` -5. Write a clear commit message following the conventional commits format +5. Ensure the documentation is up-to-date +6. Write a clear commit message following the conventional commits format + diff --git a/CHANGELOG.md b/CHANGELOG.md index 0c3298b7..88b12834 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,7 @@ All notable unreleased changes to this project will be documented in this file. For released versions, see the [Releases](https://github.com/mirumee/ariadne/releases) page. -## Unreleased - +## 1.1.0a3 (2026-05-06) +### ✨ New Features +- Automated SQLAlchemy integration diff --git a/README.md b/README.md index dbccecb2..e0db6686 100644 --- a/README.md +++ b/README.md @@ -42,6 +42,7 @@ Documentation is available [here](https://ariadnegraphql.org). - Loading schema from `.graphql`, `.gql`, and `.graphqls` files. - ASGI and WSGI support, with integrations for Django, FastAPI, Flask, and Starlette. - Opt-in automatic resolvers mapping between `camelCase` and `snake_case`. 
+- Automated integration with **SQLAlchemy 2.0** for zero-boilerplate resolvers and N+1 prevention. - [OpenTelemetry](https://opentelemetry.io/) extension for API monitoring. - Built-in [GraphiQL](https://github.com/graphql/graphiql) explorer for development and testing. - GraphQL syntax validation via `gql()` helper function. diff --git a/ariadne/contrib/sqlalchemy/__init__.py b/ariadne/contrib/sqlalchemy/__init__.py new file mode 100644 index 00000000..72806f4c --- /dev/null +++ b/ariadne/contrib/sqlalchemy/__init__.py @@ -0,0 +1,22 @@ +try: + from .dataloaders import LoaderRegistry, SQLAlchemyDataLoader + from .extension import SQLAlchemyDataLoaderExtension + from .objects import SQLAlchemyObjectType + from .query import SQLAlchemyQueryType + from .types import LoadStrategy + from .utils import auto_eager_load +except ImportError as ex: + raise ImportError( + "SQLAlchemy integration requires the 'sqlalchemy' and 'aiodataloader' " + "packages. Install them using 'pip install \"ariadne[sqlalchemy]\"'." 
+ ) from ex + +__all__ = [ + "SQLAlchemyDataLoaderExtension", + "LoadStrategy", + "LoaderRegistry", + "SQLAlchemyObjectType", + "SQLAlchemyQueryType", + "SQLAlchemyDataLoader", + "auto_eager_load", +] diff --git a/ariadne/contrib/sqlalchemy/dataloaders.py b/ariadne/contrib/sqlalchemy/dataloaders.py new file mode 100644 index 00000000..580af87b --- /dev/null +++ b/ariadne/contrib/sqlalchemy/dataloaders.py @@ -0,0 +1,124 @@ +import inspect +import logging +from collections import defaultdict +from typing import Any + +from aiodataloader import DataLoader +from sqlalchemy import select, tuple_ +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import RelationshipProperty, Session + +logger = logging.getLogger(__name__) + + +class SQLAlchemyDataLoader(DataLoader): + """ + DataLoader for SQLAlchemy relationships supporting: + - Composite Keys + - Many-to-Many (secondary tables) + - Result grouping via SQL columns (optimized) + """ + + def __init__( + self, + session: Session | AsyncSession, + relation_prop: RelationshipProperty, + cache: bool = True, + ): + super().__init__(cache=cache) + self.session = session + self.relation_prop = relation_prop + self.target_model = relation_prop.mapper.class_ + self.is_list = relation_prop.uselist + + # Identify local and remote columns (handles composite keys) + if relation_prop.secondary is not None: + self.local_cols = [ + lp.key + for lp, rp in relation_prop.synchronize_pairs + if lp.key is not None + ] + self.remote_cols = [ + rp.key + for lp, rp in relation_prop.synchronize_pairs + if rp.key is not None + ] + else: + self.local_cols = [ + c.key for c in relation_prop.local_columns if c.key is not None + ] + self.remote_cols = [ + c.key for c in relation_prop.remote_side if c.key is not None + ] + + self.secondary = relation_prop.secondary + + def get_query(self, keys: list[Any]): + """Builds query. 
Handles composite IN clause and M2M joins.""" + target_model = self.target_model + stmt = select(target_model) + + if self.secondary is not None: + stmt = stmt.join(self.secondary) + filter_cols = [self.secondary.c[k] for k in self.remote_cols] + else: + filter_cols = [getattr(target_model, k) for k in self.remote_cols] + + # Add the filtering columns to the result to allow grouping + stmt = stmt.add_columns(*filter_cols) + + if len(filter_cols) > 1: + stmt = stmt.where(tuple_(*filter_cols).in_(keys)) + else: + # Flatten keys if they are single-element tuples + flat_keys = [k[0] if isinstance(k, (list, tuple)) else k for k in keys] + stmt = stmt.where(filter_cols[0].in_(flat_keys)) + + return stmt + + async def batch_load_fn(self, keys: list[Any]) -> list[Any]: + logger.debug( + "SQLAlchemyRelationLoader: Fetching %s for %d parents", + self.target_model.__name__, + len(keys), + ) + stmt = self.get_query(keys) + + result = self.session.execute(stmt) + if inspect.isawaitable(result): + result = await result + + rows = result.all() # type: ignore + + num_filter_cols = len(self.remote_cols) + grouped = defaultdict(list) + + for row in rows: + item = row[0] + # The filter columns are appended after the model instance + key_parts = row[1 : 1 + num_filter_cols] + key = tuple(key_parts) if num_filter_cols > 1 else key_parts[0] + grouped[key].append(item) + + return [ + grouped[k] if self.is_list else (grouped[k][0] if grouped[k] else None) + for k in keys + ] + + +class LoaderRegistry: + def __init__(self, session: Session | AsyncSession): + self.session = session + self._loaders: dict[ + tuple[RelationshipProperty, type[DataLoader]], DataLoader + ] = {} + + def get_loader( + self, + relation_prop: RelationshipProperty, + loader_class: type[SQLAlchemyDataLoader] = SQLAlchemyDataLoader, + ) -> DataLoader: + key = (relation_prop, loader_class) + if key not in self._loaders: + self._loaders[key] = loader_class(self.session, relation_prop) + return self._loaders[key] diff 
--git a/ariadne/contrib/sqlalchemy/extension.py b/ariadne/contrib/sqlalchemy/extension.py new file mode 100644 index 00000000..92912d6d --- /dev/null +++ b/ariadne/contrib/sqlalchemy/extension.py @@ -0,0 +1,26 @@ +from typing import Any + +from ...types import Extension +from .dataloaders import LoaderRegistry + + +class SQLAlchemyDataLoaderExtension(Extension): + """Ariadne extension that creates a per-request `LoaderRegistry`. + + Wires the SQLAlchemy DataLoader fallback path automatically: at the start + of each GraphQL request, reads the session from `context[session_key]` + and writes a fresh `LoaderRegistry(session)` to `context[registry_key]`. + + """ + + def __init__( + self, + *, + session_key: str = "session", + registry_key: str = "loader_registry", + ): + self.session_key = session_key + self.registry_key = registry_key + + def request_started(self, context: Any) -> None: + context[self.registry_key] = LoaderRegistry(context[self.session_key]) diff --git a/ariadne/contrib/sqlalchemy/objects.py b/ariadne/contrib/sqlalchemy/objects.py new file mode 100644 index 00000000..cc238010 --- /dev/null +++ b/ariadne/contrib/sqlalchemy/objects.py @@ -0,0 +1,124 @@ +from __future__ import annotations + +from collections.abc import Callable +from typing import Any, cast + +from graphql import GraphQLObjectType, GraphQLSchema +from sqlalchemy import select +from sqlalchemy.orm import DeclarativeBase, RelationshipProperty, class_mapper + +from ...objects import ObjectType +from .dataloaders import LoaderRegistry +from .types import LoadStrategy + + +class SQLAlchemyObjectType(ObjectType): + """ + ObjectType specialized for SQLAlchemy models. + Automatically binds resolvers for relationships using DataLoaders. 
+ """ + + model: type[DeclarativeBase] + aliases: dict[str, str] + strategies: dict[str, LoadStrategy] + max_depth: int + _registry_key: str + + def __init__( + self, + name: str, + model: type[DeclarativeBase], + *, + aliases: dict[str, str] | Callable[[], dict[str, str]] | None = None, + strategies: dict[str, LoadStrategy] | None = None, + max_depth: int = 3, + ): + super().__init__(name) + self.model = model + self.aliases = aliases() if callable(aliases) else (aliases or {}) # ty: ignore[call-top-callable] + self.strategies = strategies or {} + self.max_depth = max_depth + + def bind_to_schema(self, schema: GraphQLSchema) -> None: + """Binds this `SQLAlchemyObjectType` to the GraphQL schema. + + Auto-generates resolvers for the model's relationships and aliased + columns, then delegates to `ObjectType.bind_to_schema` to wire them + (along with any explicitly-set resolvers) onto the schema's fields. + + The auto-resolvers must be registered before calling `super()` so + they are included when the parent iterates `self._resolvers` to + populate the GraphQL type's field `resolve` attributes. + """ + graphql_type = schema.type_map.get(self.name) + self.validate_graphql_type(graphql_type) + self._bind_auto_resolvers(cast(GraphQLObjectType, graphql_type)) + super().bind_to_schema(schema) + + def get_base_query(self, info: Any, **kwargs: Any): + """ + Returns the base SQLAlchemy select statement for root queries. + Can be overridden to apply default filters. 
+ """ + return select(self.model) + + def _bind_auto_resolvers(self, graphql_type: GraphQLObjectType) -> None: + schema_fields = graphql_type.fields + mapper = class_mapper(self.model) + + for gql_field, db_attr in self.aliases.items(): + if gql_field not in schema_fields: + continue + if callable(db_attr): + self.set_field(gql_field, db_attr) + else: + self.set_field( + gql_field, lambda obj, *_, _attr=db_attr: getattr(obj, _attr) + ) + + for relation in mapper.relationships: + if relation.key not in schema_fields: + continue + if relation.key in self._resolvers: + continue + self.set_field(relation.key, self._create_relation_resolver(relation)) + + @staticmethod + def get_loader_registry_from_context(context: Any) -> LoaderRegistry: + """Get the `LoaderRegistry` from the GraphQL context. + + Override this method to customize how the registry is retrieved. + """ + try: + return context["loader_registry"] + except KeyError: + raise RuntimeError( + "LoaderRegistry not found in context under key 'loader_registry'" + ) + + def _create_relation_resolver(self, relation: RelationshipProperty): + async def resolve(obj: Any, info: Any, **kwargs: Any): + # If the attribute is already loaded (e.g. via joinedload/selectinload), + # return it + if relation.key in obj.__dict__: + return getattr(obj, relation.key) + + loader_registry = self.get_loader_registry_from_context(info.context) + + # Identify which column(s) on the current object connect it to the + # target table. For a One-to-Many, this is usually a Foreign Key. + local_relation_columns = [ + c.key for c in relation.local_columns if c.key is not None + ] + + # Extract the actual database values from this specific object instance. + join_values = tuple(getattr(obj, col) for col in local_relation_columns) + + # If it's a standard single-column relationship, unwrap the tuple to just + # the ID. If it's a composite key, keep the tuple. 
+ lookup_key = join_values[0] if len(join_values) == 1 else join_values + + loader = loader_registry.get_loader(relation) + return await loader.load(lookup_key) + + return resolve diff --git a/ariadne/contrib/sqlalchemy/query.py b/ariadne/contrib/sqlalchemy/query.py new file mode 100644 index 00000000..a3c3baa7 --- /dev/null +++ b/ariadne/contrib/sqlalchemy/query.py @@ -0,0 +1,89 @@ +import inspect +from collections.abc import Sequence +from typing import Any + +from graphql import GraphQLList, GraphQLNonNull, GraphQLObjectType, GraphQLSchema +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import Session + +from ...objects import QueryType +from .objects import SQLAlchemyObjectType +from .utils import auto_eager_load + + +class SQLAlchemyQueryType(QueryType): + """ + A custom Query type that automatically binds SQLAlchemy resolvers + by inspecting the GraphQLSchema during the make_executable_schema build phase. + """ + + def __init__( + self, + object_types: Sequence[SQLAlchemyObjectType], + ): + super().__init__() + self.object_types = {ot.name: ot for ot in object_types} + self._object_types_by_model = {ot.model: ot for ot in object_types} + + @staticmethod + def get_session_from_context(context: Any) -> Session | AsyncSession: + try: + return context["session"] + except KeyError: + raise RuntimeError("Session not found in context under key 'session'") + + def bind_to_schema(self, schema: GraphQLSchema) -> None: + graphql_type = schema.type_map.get(self.name) + if not isinstance(graphql_type, GraphQLObjectType): + super().bind_to_schema(schema) + return + + for field_name, field_def in graphql_type.fields.items(): + is_list = False + unwrapped_type = field_def.type + + while isinstance(unwrapped_type, (GraphQLList, GraphQLNonNull)): + if isinstance(unwrapped_type, GraphQLList): + is_list = True + unwrapped_type = unwrapped_type.of_type + + type_name = getattr(unwrapped_type, "name", None) + + if type_name in self.object_types and field_name 
not in self._resolvers: + obj_type = self.object_types[type_name] + self.set_field( + field_name, self._create_auto_resolver(obj_type, is_list) + ) + + super().bind_to_schema(schema) + + def _create_auto_resolver(self, obj_type: SQLAlchemyObjectType, return_list: bool): + async def auto_resolve(obj: Any, info: Any, **kwargs: Any): + session = self.get_session_from_context(info.context) + + model = obj_type.model + stmt = obj_type.get_base_query(info, **kwargs) + + stmt = auto_eager_load( + stmt, + info, + model, + strategies=obj_type.strategies, + aliases=obj_type.aliases, + max_depth=obj_type.max_depth, + type_registry=self._object_types_by_model, + ) + + for key, value in kwargs.items(): + db_col_name = obj_type.aliases.get(key, key) + if hasattr(model, db_col_name): + stmt = stmt.where(getattr(model, db_col_name) == value) + + result = session.execute(stmt) + if inspect.isawaitable(result): + result = await result + if return_list: + return result.scalars().unique().all() # type: ignore + return result.scalars().first() # type: ignore + + return auto_resolve diff --git a/ariadne/contrib/sqlalchemy/types.py b/ariadne/contrib/sqlalchemy/types.py new file mode 100644 index 00000000..921caf3d --- /dev/null +++ b/ariadne/contrib/sqlalchemy/types.py @@ -0,0 +1,19 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Protocol + +if TYPE_CHECKING: + from sqlalchemy.orm.strategy_options import _AbstractLoad + + +class LoadStrategy(Protocol): + """ + SQLAlchemy relationship loading strategy functions (``joinedload``, + ``selectinload``, ``subqueryload``, ``lazyload``, ``raiseload``, ``noload``, + ``immediateload``, ``contains_eager``, ``defaultload``). + """ + + @property + def __name__(self) -> str: ... + + def __call__(self, *args: Any, **kwargs: Any) -> _AbstractLoad: ... 
diff --git a/ariadne/contrib/sqlalchemy/utils.py b/ariadne/contrib/sqlalchemy/utils.py new file mode 100644 index 00000000..263cd7fe --- /dev/null +++ b/ariadne/contrib/sqlalchemy/utils.py @@ -0,0 +1,193 @@ +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Any + +from graphql import GraphQLError +from sqlalchemy.orm import class_mapper, joinedload, load_only, selectinload + +from .types import LoadStrategy + +if TYPE_CHECKING: + from collections.abc import Sequence + + from sqlalchemy.orm import Mapper, RelationshipProperty + from sqlalchemy.orm.strategy_options import _AbstractLoad + + from .objects import SQLAlchemyObjectType + +logger = logging.getLogger(__name__) + + +def _resolve_load_option( + mapper: Mapper[Any], + db_attr: str, + strategy: LoadStrategy | None, + rel: RelationshipProperty[Any], + load_path: _AbstractLoad | None, +) -> _AbstractLoad: + """Resolve a SQLAlchemy loading strategy for a relationship attribute. + + Args: + mapper: The SQLAlchemy mapper for the current model. + db_attr: The database attribute name of the relationship. + strategy: A SQLAlchemy loader function (e.g. ``selectinload``, + ``joinedload``, ``subqueryload``). When ``None``, defaults to + ``selectinload`` for collections and ``joinedload`` for scalars. + rel: The SQLAlchemy relationship property being loaded. + load_path: The parent loading chain to nest this option under. + At the root level this is ``None`` and the strategy is called + directly (e.g. ``selectinload(Post.tags)``). For nested + relationships it is the ``_AbstractLoad`` returned by the + parent strategy call, and the corresponding method is chained + onto it (e.g. ``selectinload(Post.tags).selectinload(Tag.posts)``). + + Returns: + A SQLAlchemy load option that can be passed to ``Query.options()``. 
+ """ + attr = getattr(mapper.class_, db_attr) + if strategy is None: + strategy = selectinload if rel.uselist else joinedload + if load_path is not None: + method_name: str = getattr(strategy, "__name__") + return getattr(load_path, method_name)(attr) + return strategy(attr) + + +def _build_options( + mapper: Mapper[Any], + selections: Sequence[Any], + strategies: dict[str, LoadStrategy], + aliases: dict[str, str], + type_depths: dict[type[Any], int], + load_path: _AbstractLoad | None = None, + type_registry: dict[type[Any], SQLAlchemyObjectType] | None = None, +) -> list[_AbstractLoad]: + current_type = mapper.class_ + current_depth = type_depths.get(current_type, 0) + + # Look up this type's config from registry + type_config = type_registry.get(current_type) if type_registry else None + max_depth = type_config.max_depth if type_config else 3 + + if current_depth > max_depth: + type_name = current_type.__name__ + raise GraphQLError( + f"Query exceeds max_depth={max_depth} for type '{type_name}'. " + f"Current depth: {current_depth}." + ) + + options = [] + + # Process scalar fields using load_only + scalar_fields = [] + for s in selections: + gql_field = s.name.value + db_attr = aliases.get(gql_field, gql_field) + if ( + db_attr not in mapper.relationships + and hasattr(mapper.class_, db_attr) + and not callable(getattr(mapper.class_, db_attr)) + ): + scalar_fields.append(db_attr) + + if scalar_fields: + # Ensure local columns of all relationships (e.g., Foreign Keys) + # are also loaded to prevent N+1 queries when falling back to DataLoaders. 
+ for rel in mapper.relationships.values(): + for col in rel.local_columns: + if ( + col.key is not None + and hasattr(mapper.class_, col.key) + and col.key not in scalar_fields + ): + scalar_fields.append(col.key) + + class_attrs = [getattr(mapper.class_, s) for s in scalar_fields] + if load_path is not None: + options.append(load_path.load_only(*class_attrs)) + else: + options.append(load_only(*class_attrs)) + + # Process relationships + for field_node in selections: + gql_field = field_node.name.value + db_attr = aliases.get(gql_field, gql_field) + + if db_attr in mapper.relationships: + rel = mapper.relationships[db_attr] + target_type = rel.mapper.class_ + + # Look up child type's config + child_config = type_registry.get(target_type) if type_registry else None + child_strategies = child_config.strategies if child_config else strategies + child_aliases = child_config.aliases if child_config else {} + + # Increment depth for target type + new_type_depths = type_depths.copy() + new_type_depths[target_type] = new_type_depths.get(target_type, 0) + 1 + + strategy = strategies.get(gql_field) + opt = _resolve_load_option(mapper, db_attr, strategy, rel, load_path) + options.append(opt) + + if field_node.selection_set: + nested_options = _build_options( + rel.mapper, + field_node.selection_set.selections, + child_strategies, + child_aliases, + new_type_depths, + load_path=opt, + type_registry=type_registry, + ) + options.extend(nested_options) + + return options + + +def auto_eager_load( + query: Any, + info: Any, + model: type[Any], + strategies: dict[str, LoadStrategy] | None = None, + aliases: dict[str, str] | None = None, + max_depth: int = 3, + type_registry: dict[type[Any], SQLAlchemyObjectType] | None = None, +) -> Any: + """Automatically apply eager loading options based on the GraphQL selection set. 
+ + Inspects the incoming GraphQL query and adds ``selectinload()``, + ``joinedload()``, and ``load_only()`` (or any other SQLAlchemy loading + strategy) for fields found in the selection set. + + Depth is tracked per-type: each type counts how many times it has been + entered from the root. When a type's depth exceeds its ``max_depth``, + a ``GraphQLError`` is raised. + """ + resolved_strategies: dict[str, LoadStrategy] = strategies or {} + resolved_aliases: dict[str, str] = aliases or {} + mapper = class_mapper(model) + selections = [] + + for field_node in info.field_nodes: + if field_node.selection_set: + selections.extend(field_node.selection_set.selections) + + if not selections: + return query + + type_depths = {model: 1} + + options = _build_options( + mapper, + selections, + resolved_strategies, + resolved_aliases, + type_depths, + type_registry=type_registry, + ) + if options: + query = query.options(*options) + + return query diff --git a/ariadne/contrib/tracing/utils.py b/ariadne/contrib/tracing/utils.py index 7dcfc29c..484ba633 100644 --- a/ariadne/contrib/tracing/utils.py +++ b/ariadne/contrib/tracing/utils.py @@ -59,9 +59,10 @@ def repr_upload_file(upload_file: UploadFile | File) -> str: def format_path(path: ResponsePath): elements = [] - while path: - elements.append(path.key) - path = path.prev + current: ResponsePath | None = path + while current: + elements.append(current.key) + current = current.prev return elements[::-1] diff --git a/ariadne/objects.py b/ariadne/objects.py index 9454c0fe..1a0d2760 100644 --- a/ariadne/objects.py +++ b/ariadne/objects.py @@ -212,7 +212,7 @@ def set_field(self, name, resolver: Resolver) -> Resolver: self._resolvers[name] = resolver return resolver - def set_alias(self, name: str, to: str) -> None: + def set_alias(self, name: str, to: str | Callable) -> None: """Set an alias resolver for the field name to given Python name. 
# Required arguments @@ -220,9 +220,13 @@ def set_alias(self, name: str, to: str) -> None: `name`: a `str` with a name of the GraphQL object's field in GraphQL schema to set this resolver for. - `to`: a `str` of an attribute or dict key to resolve this field to. + `to`: a `str` of an attribute or dict key to resolve this field to, + or a `Callable`. """ - self._resolvers[name] = resolve_to(to) + if callable(to): + self._resolvers[name] = to + else: + self._resolvers[name] = resolve_to(to) def bind_to_schema(self, schema: GraphQLSchema) -> None: """Binds this `ObjectType` instance to the instance of GraphQL schema. @@ -258,11 +262,11 @@ def bind_resolvers_to_graphql_type(self, graphql_type, replace_existing=True): class QueryType(ObjectType): - """An convenience class for defining Query type. + """A convenience class for defining Query type. # Example - Both of those code samples have same effect: + Both of those code samples have the same effect: ```python query_type = QueryType() @@ -279,11 +283,11 @@ def __init__(self) -> None: class MutationType(ObjectType): - """An convenience class for defining Mutation type. + """A convenience class for defining Mutation type. # Example - Both of those code samples have same result: + Both of those code samples have the same result: ```python mutation_type = MutationType() diff --git a/ariadne/schema_visitor.py b/ariadne/schema_visitor.py index 73c6d64f..90f63dca 100644 --- a/ariadne/schema_visitor.py +++ b/ariadne/schema_visitor.py @@ -725,9 +725,8 @@ def heal_type( # any `GraphQLNamedType` with a `name`, then it must end up identical # to `schema.get_type(name)`, since `schema.type_map` is the source # of truth for all named schema types. 
- named_type = cast(GraphQLNamedType, type_) - official_type = schema.get_type(named_type.name) - if official_type and named_type != official_type: + official_type = schema.get_type(type_.name) + if official_type and type_ != official_type: return official_type return type_ diff --git a/docs/01-Docs/13-dataloaders.md b/docs/01-Docs/13-dataloaders.md index 197a505a..aa56bc98 100644 --- a/docs/01-Docs/13-dataloaders.md +++ b/docs/01-Docs/13-dataloaders.md @@ -492,4 +492,12 @@ async def resolve_move_category_contents(_, info, **kwargs): return {"success": True} ``` -Unlike `DataLoader`, `SyncDataLoader` doesn't provide an API for clearing entire cache or priming objects. \ No newline at end of file +Unlike `DataLoader`, `SyncDataLoader` doesn't provide an API for clearing entire cache or priming objects. + +## SQLAlchemy Integration + +If your project uses **SQLAlchemy 2.0** on an **async** stack (ASGI + `AsyncSession`), Ariadne provides an optional `ariadne.contrib.sqlalchemy` package that automates the creation of DataLoaders and relationship resolvers, ensuring "zero-boilerplate" N+1 prevention with advanced eager loading (lookahead optimization) support. + +The contrib is async-only — it builds on `aiodataloader` and does not work with `graphql_sync` / `SyncDataLoader`. For sync stacks, use the manual sync DataLoader pattern shown above. + +See the [SQLAlchemy Integration](../07-Contrib/02-sqlalchemy.md) guide for details. \ No newline at end of file diff --git a/docs/01-Docs/17-bindables.md b/docs/01-Docs/17-bindables.md index db7ca1dd..b59a83df 100644 --- a/docs/01-Docs/17-bindables.md +++ b/docs/01-Docs/17-bindables.md @@ -28,3 +28,8 @@ class MyCustomType(SchemaBindable): ``` `bind_to_schema` is called during executable schema creation. 
+ +## Built-in Custom Bindables + +Ariadne provides advanced custom bindables as part of its integrations: +- `SQLAlchemyObjectType` and `SQLAlchemyQueryType` (see [SQLAlchemy Integration](../07-Contrib/02-sqlalchemy.md)) demonstrate how custom bindables can automate resolver generation and schema inspection. diff --git a/docs/07-Contrib/02-sqlalchemy.md b/docs/07-Contrib/02-sqlalchemy.md new file mode 100644 index 00000000..f2c76751 --- /dev/null +++ b/docs/07-Contrib/02-sqlalchemy.md @@ -0,0 +1,257 @@ +--- +id: sqlalchemy +title: SQLAlchemy Integration +sidebar_label: SQLAlchemy +--- + +Ariadne provides an optional integration for [SQLAlchemy 2.0](https://www.sqlalchemy.org/) that simplifies building GraphQL APIs on top of SQLAlchemy models. + +The integration has two execution paths that work together: + +1. **The `auto_eager_load` path** - `SQLAlchemyQueryType` builds *one* optimised SQL statement per top-level field by walking the GraphQL selection set ahead of time and applying `selectinload` / `joinedload` / `load_only` to whatever the client asked for. Most queries against an auto-resolved schema are fully served by this path. +2. **The DataLoader fallback path** - when a relationship is reached on a parent object that was *not* prepared by `auto_eager_load` (e.g. an object returned by a manual `@field` resolver), `SQLAlchemyObjectType` falls back to a `SQLAlchemyDataLoader` that batches the loads and prevents N+1. + +This document is organised around those two paths, both of which are covered here end-to-end. + +## Installation + +Install `ariadne` with the `sqlalchemy` extra: + +```console +pip install "ariadne[sqlalchemy]" +``` + +This installs `sqlalchemy` and `aiodataloader`. + + +> **Note:** The examples in this guide focus on the Ariadne integration and omit the infrastructure for database session management (such as middleware or extensions for automatic session creation and teardown). 
In production, you should use a framework-specific middleware or an Ariadne `Extension` to ensure sessions are correctly scoped to the request and closed after execution. + + +## Quick Start + +The minimal correct setup uses a synchronous `Session` and puts it in the GraphQL context. + +```python +from ariadne import make_executable_schema +from ariadne.asgi import GraphQL +from ariadne.asgi.handlers import GraphQLHTTPHandler +from ariadne.contrib.sqlalchemy import ( + SQLAlchemyObjectType, + SQLAlchemyQueryType, +) +from sqlalchemy import Column, ForeignKey, Integer, String, Table, create_engine +from sqlalchemy.orm import Session, declarative_base, relationship, sessionmaker + +Base = declarative_base() + +post_tags = Table( + "post_tags", + Base.metadata, + Column("post_id", Integer, ForeignKey("posts.id"), primary_key=True), + Column("tag_id", Integer, ForeignKey("tags.id"), primary_key=True), +) + + +class User(Base): + __tablename__ = "users" + id = Column(Integer, primary_key=True) + username = Column(String) + posts = relationship("Post", back_populates="author") + + +class Post(Base): + __tablename__ = "posts" + id = Column(Integer, primary_key=True) + title = Column(String) + author_id = Column(Integer, ForeignKey("users.id")) + author = relationship("User", back_populates="posts") + tags = relationship("Tag", secondary=post_tags, back_populates="posts") + + +class Tag(Base): + __tablename__ = "tags" + id = Column(Integer, primary_key=True) + name = Column(String) + posts = relationship("Post", secondary=post_tags, back_populates="tags") + + +type_defs = """ + type Query { + users: [User!]! + posts: [Post!]! + tags: [Tag!]! + } + + type User { + id: ID! + username: String! + posts: [Post!]! + } + + type Post { + id: ID! + title: String! + author: User! + tags: [Tag!]! + } + + type Tag { + id: ID! + name: String! + posts: [Post!]! 
+ } +""" + +user_type = SQLAlchemyObjectType("User", User) +post_type = SQLAlchemyObjectType("Post", Post) +tag_type = SQLAlchemyObjectType("Tag", Tag) +query = SQLAlchemyQueryType([user_type, post_type, tag_type]) + +schema = make_executable_schema(type_defs, [query, user_type, post_type, tag_type]) + +engine = create_engine("sqlite:///db.sqlite3") +SessionLocal = sessionmaker(engine, expire_on_commit=False) + + +async def get_context(request, _data): + return {"request": request, "session": SessionLocal()} + + +app = GraphQL( + schema, + context_value=get_context, + http_handler=GraphQLHTTPHandler(), +) +``` + +`SQLAlchemyQueryType` reads the session from `info.context["session"]` and feeds it through the `auto_eager_load` path described below. Ariadne's `context_value` callable must return the context dict (or an awaitable that resolves to one) - async generator / `yield`-based forms are not supported, so deterministic per-request session cleanup is best handled via a custom `Extension` (the same mechanism `SQLAlchemyDataLoaderExtension` uses). Without one, the session is closed when its connection is returned to the pool by GC. + +A complete runnable version of this Quick Start (in-memory SQLite, seed data, sample queries) lives in [`examples/sqlalchemy/01_auto_eager_load.py`](https://github.com/mirumee/ariadne/tree/main/examples/sqlalchemy/01_auto_eager_load.py). + +### Reading the session from somewhere other than `context["session"]` + +A common production pattern is to open the SQLAlchemy session in middleware (so the same scope handles both the GraphQL request and any other endpoints), attach it to `request.state.session`, and let resolvers read from there. `SQLAlchemyQueryType` looks the session up via `get_session_from_context`, which is a static method you can override on a subclass. 
Use that subclass everywhere you would have used `SQLAlchemyQueryType`: + +```python + +class MyType(SQLAlchemyQueryType): + @staticmethod + def get_session_from_context(context): + return context["request"].state.session + +query = MyType([user_type, post_type]) +``` + +### Tuning per-type behaviour: `aliases`, `strategies`, `max_depth` + +`SQLAlchemyObjectType` accepts three keyword-only arguments that change how its instance behaves on the `auto_eager_load` path: + +- **`aliases`** — map a GraphQL field name to a different SQLAlchemy attribute. Honoured by both relationship resolution and the `load_only` column optimisation. Pass a dict, or a zero-arg callable returning a dict for lazy initialisation. +- **`strategies`** — override the default loader strategy (`selectinload` for collections, `joinedload` for scalars) on a per-relationship basis. Any SQLAlchemy loader function works — `selectinload`, `joinedload`, `subqueryload`, etc. +- **`max_depth`** — cap how deep `auto_eager_load` walks into this type from the root. Tracked **per-type**: each entry into the same type counts. Exceeding it raises `GraphQLError`. Defaults to `3`. + +Minimal example covering all three: + +```python +from sqlalchemy.orm import selectinload + +post_type = SQLAlchemyObjectType( + "Post", + Post, + aliases={"my_post_id": "post_id"}, + strategies={"author": selectinload, "tags": selectinload}, + max_depth=4, +) + +query = SQLAlchemyQueryType([user_type, post_type]) +``` + + +----- +Custom resolvers and DataLoaders +----- + +### Required setup for DataLoaders + +For the DataLoader path to work, `loader_registry` must be present in `info.context`. 
The recommended way to put it there is the `SQLAlchemyDataLoaderExtension` — the bare class is a zero-arg callable, so you can pass it directly to the HTTP handler: + +```python +from ariadne.asgi import GraphQL +from ariadne.asgi.handlers import GraphQLHTTPHandler +from ariadne.contrib.sqlalchemy import SQLAlchemyDataLoaderExtension + +app = GraphQL( + schema, + context_value=get_context, # only puts "session" in context now + http_handler=GraphQLHTTPHandler(extensions=[SQLAlchemyDataLoaderExtension]), +) +``` + +With the extension enabled, your `get_context` only needs to put a `session` in the context — the extension creates a fresh `LoaderRegistry` per request before any resolver runs and writes it to `context["loader_registry"]`. + +To use different keys or a custom `LoaderRegistry` subclass, pass the appropriate arguments to the extension: + +```python +extensions=[ + SQLAlchemyDataLoaderExtension( + session_key="db", + registry_key="loaders", + ), +] +``` + +If you'd rather wire the registry in `get_context` yourself, that still works — the extension's only job is to do this for you: + +```python +from ariadne.contrib.sqlalchemy import LoaderRegistry + + +async def get_context(request, _data): + session = request.state.session + return { + "request": request, + "session": session, + "loader_registry": LoaderRegistry(session), + } +``` + + +The `auto_eager_load` path only fires for fields that go through `SQLAlchemyQueryType`'s auto-resolver. The moment you write a manual `@field` resolver that runs its own `select(...)` and returns ORM objects, those rows bypass `auto_eager_load` entirely — their relationships are not pre-loaded, and resolving `author`/`tags`/etc. on them falls through to a per-request DataLoader. The DataLoader collects every per-row lookup for one relationship into a single batched SQL statement, so N+1 is avoided here too — by batching rather than by lookahead. 
+ +```python +@query.field("publishedPosts") +def resolve_published_posts(_, info): + session = info.context["session"] + stmt = select(Post).where(Post.is_published) + return session.execute(stmt).scalars().unique().all() +``` + +For the GraphQL query: + +```graphql +{ + publishedPosts { + title + author { username } + tags { name } + } +} +``` + +…the integration runs three SQL statements: the manual `WHERE is_published` query, then one batched `WHERE id IN (...)` for `author`, then one joined query through `post_tags` for `tags`. + +A runnable version of this scenario lives at [`examples/sqlalchemy/02_dataloader_fallback.py`](https://github.com/mirumee/ariadne/tree/main/examples/sqlalchemy/02_dataloader_fallback.py). Run it with `echo=True` and watch the SQL log to confirm the three-statement count. + + + +The registry **must** be created per request — sharing it across requests would leak DataLoader caches (and therefore data) between users. The simplest correct lifetime is "scoped to the same `Session` you put in the context". + +To customise the lookup, override `SQLAlchemyObjectType.get_loader_registry_from_context` on a subclass - the same pattern as `get_session_from_context`: + +```python +class MyObjectType(SQLAlchemyObjectType): + @staticmethod + def get_loader_registry_from_context(context): + return context["request"].state.loaders +``` + diff --git a/examples/sqlalchemy/01_auto_eager_load.py b/examples/sqlalchemy/01_auto_eager_load.py new file mode 100644 index 00000000..f3077654 --- /dev/null +++ b/examples/sqlalchemy/01_auto_eager_load.py @@ -0,0 +1,192 @@ +""" +This is the simplest correct setup of the SQLAlchemy integration. Every root +field on the schema is auto-resolved by `SQLAlchemyQueryType`, so the only +thing the GraphQL context needs is a SQLAlchemy `session`. + +Self-contained: a single file, an in-memory SQLite database, a synchronous +`Session`, and the schema/seed data inline. 
Run it and hit the endpoint with +queries like `{ posts { title author { username } } }` to see one optimised +SQL statement issued per top-level field. + +Note on async: SQLAlchemy's `AsyncSession` *can* be plugged in by swapping +`create_engine`/`sessionmaker` for `create_async_engine`/`async_sessionmaker`, +but two sibling root resolvers awaiting the same `AsyncSession` race on its +single underlying connection and SQLAlchemy raises +`InvalidRequestError: This session is provisioning a new connection; +concurrent operations are not permitted`. The synchronous `Session` used +here executes sequentially by definition and is unaffected. + +Run with: + + uv run \ + --with "uvicorn[standard]" \ + --with ariadne \ + --with "sqlalchemy" \ + --with "aiodataloader" \ + uvicorn examples.sqlalchemy.01_auto_eager_load:app --reload +""" + +from contextlib import asynccontextmanager + +from sqlalchemy import Column, ForeignKey, Integer, String, Table, create_engine +from sqlalchemy.orm import Session, declarative_base, relationship, sessionmaker +from sqlalchemy.pool import StaticPool +from starlette.applications import Starlette +from starlette.middleware import Middleware +from starlette.middleware.base import BaseHTTPMiddleware + +from ariadne import make_executable_schema +from ariadne.asgi import GraphQL +from ariadne.asgi.handlers import GraphQLHTTPHandler +from ariadne.contrib.sqlalchemy import ( + SQLAlchemyObjectType, + SQLAlchemyQueryType, +) + +# --- Database -------------------------------------------------------------- + +Base = declarative_base() + +post_tags = Table( + "post_tags", + Base.metadata, + Column("post_id", Integer, ForeignKey("posts.id"), primary_key=True), + Column("tag_id", Integer, ForeignKey("tags.id"), primary_key=True), +) + + +class User(Base): + __tablename__ = "users" + id = Column(Integer, primary_key=True) + username = Column(String, unique=True) + posts = relationship("Post", back_populates="author") + + +class Post(Base): + 
__tablename__ = "posts" + id = Column(Integer, primary_key=True) + title = Column(String) + author_id = Column(Integer, ForeignKey("users.id")) + author = relationship("User", back_populates="posts") + tags = relationship("Tag", secondary=post_tags, back_populates="posts") + + +class Tag(Base): + __tablename__ = "tags" + id = Column(Integer, primary_key=True) + name = Column(String) + posts = relationship("Post", secondary=post_tags, back_populates="tags") + + +engine = create_engine( + "sqlite:///:memory:", + echo=True, + connect_args={"check_same_thread": False}, + poolclass=StaticPool, +) +SessionLocal = sessionmaker(engine, expire_on_commit=False) + + +class DBSessionMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request, call_next): + with SessionLocal() as session: + request.state.session = session + return await call_next(request) + + +def init_db() -> None: + Base.metadata.create_all(engine) + with Session(engine) as session: + alice = User(username="alice") + bob = User(username="bob") + python_tag = Tag(name="Python") + graphql_tag = Tag(name="GraphQL") + session.add_all( + [ + alice, + bob, + python_tag, + graphql_tag, + Post( + title="Hello, GraphQL", + author=alice, + tags=[python_tag, graphql_tag], + ), + Post(title="SQLAlchemy 2.0 tips", author=bob, tags=[graphql_tag]), + ] + ) + session.commit() + + +# --- GraphQL schema -------------------------------------------------------- + +type_defs = """ + type Query { + users: [User!]! + user(id: ID!): User + + posts: [Post!]! + post(id: ID!): Post + + tags: [Tag!]! + tag(id: ID!): Tag + } + + type User { + id: ID! + username: String! + posts: [Post!]! + } + + type Post { + id: ID! + title: String! + author: User! + tags: [Tag!]! + } + + type Tag { + id: ID! + name: String! + posts: [Post!]! 
+ } +""" + +user_type = SQLAlchemyObjectType("User", User) +post_type = SQLAlchemyObjectType("Post", Post) +tag_type = SQLAlchemyObjectType("Tag", Tag) +query_type = SQLAlchemyQueryType([user_type, post_type, tag_type]) + +schema = make_executable_schema(type_defs, [query_type, user_type, post_type, tag_type]) + + +async def get_context(request, _data): + return {"request": request, "session": request.state.session} + + +graphql_app = GraphQL( + schema, + context_value=get_context, + http_handler=GraphQLHTTPHandler(), +) + + +@asynccontextmanager +async def lifespan(_app): + ## For testing purposes + init_db() + yield + + +app = Starlette( + debug=True, + lifespan=lifespan, + middleware=[Middleware(DBSessionMiddleware)], +) +app.mount("/", graphql_app) + + +if __name__ == "__main__": + import uvicorn + + uvicorn.run(app, host="0.0.0.0", port=8000) diff --git a/examples/sqlalchemy/02_dataloader_fallback.py b/examples/sqlalchemy/02_dataloader_fallback.py new file mode 100644 index 00000000..ca68295f --- /dev/null +++ b/examples/sqlalchemy/02_dataloader_fallback.py @@ -0,0 +1,255 @@ +"""SQLAlchemy contrib example - Step 2: custom resolvers and DataLoaders. + +This example demonstrates the two scenarios in which the DataLoader fallback +fires (see "Custom resolvers and DataLoaders" in the docs): + +1. **A manual `@field` resolver returns ORM objects.** `publishedPosts` runs + its own `select(Post).where(Post.is_published)` and returns the rows + directly. They have empty `__dict__` for relationships, so resolving + `author`/`tags` on them goes through `LoaderRegistry`. + +2. **`Tag` is exposed in the schema but not registered with `SQLAlchemyQueryType`.** + Only `user_type` and `post_type` are passed to the QueryType below; + `tag_type` is defined (so its column/relationship resolvers exist on the + `Tag` GraphQL type) but omitted from the QueryType registration. 
+ +What to watch in the SQL log (the engine runs with `echo=True`): + +* `{ posts { title author { username } tags { name } } }` — Step 1 path. Even + though `tag_type` isn't registered, `auto_eager_load` still applies + `selectinload(Post.tags)` using *default* per-type config (no custom + `max_depth`, `strategies`, or `aliases` for `Tag`). You'll see one SELECT + for posts with the relationships joined/selectin'd in. **No DataLoader.** + The practical effect of the missing registration is that `Tag`-specific + config is ignored, not that the DataLoader fires. +* `{ publishedPosts { title author { username } tags { name } } }` — Step 2 + path. The manual resolver bypasses lookahead, so each requested + relationship runs through the DataLoader: one batched `WHERE id IN (...)` + for `author` and one joined query through `post_tags` for `tags`, on top + of the manual `WHERE is_published` query. + +Self-contained: a single file, an in-memory SQLite database, a synchronous +`Session` (the integration also accepts an `AsyncSession`, but two sibling +root resolvers race on its single connection - see the docs caveat). 
+ +Run with: + + uv run \ + --with "uvicorn[standard]" \ + --with ariadne \ + --with "sqlalchemy" \ + --with "aiodataloader" \ + uvicorn examples.sqlalchemy.02_dataloader_fallback:app --reload +""" + +from contextlib import asynccontextmanager + +from sqlalchemy import ( + Boolean, + Column, + ForeignKey, + Integer, + String, + Table, + create_engine, + select, +) +from sqlalchemy.orm import Session, declarative_base, relationship, sessionmaker +from sqlalchemy.pool import StaticPool +from starlette.applications import Starlette +from starlette.middleware import Middleware +from starlette.middleware.base import BaseHTTPMiddleware + +from ariadne import make_executable_schema +from ariadne.asgi import GraphQL +from ariadne.asgi.handlers import GraphQLHTTPHandler +from ariadne.contrib.sqlalchemy import ( + SQLAlchemyDataLoaderExtension, + SQLAlchemyObjectType, + SQLAlchemyQueryType, +) + +# --- Database -------------------------------------------------------------- + +Base = declarative_base() + +post_tags = Table( + "post_tags", + Base.metadata, + Column("post_id", Integer, ForeignKey("posts.id"), primary_key=True), + Column("tag_id", Integer, ForeignKey("tags.id"), primary_key=True), +) + + +class User(Base): + __tablename__ = "users" + id = Column(Integer, primary_key=True) + username = Column(String, unique=True) + posts = relationship("Post", back_populates="author") + + +class Post(Base): + __tablename__ = "posts" + id = Column(Integer, primary_key=True) + title = Column(String) + is_published = Column(Boolean, default=True) + author_id = Column(Integer, ForeignKey("users.id")) + author = relationship("User", back_populates="posts") + tags = relationship("Tag", secondary=post_tags, back_populates="posts") + + +class Tag(Base): + __tablename__ = "tags" + id = Column(Integer, primary_key=True) + name = Column(String) + posts = relationship("Post", secondary=post_tags, back_populates="tags") + + +engine = create_engine( + "sqlite:///:memory:", + echo=True, + 
connect_args={"check_same_thread": False}, + poolclass=StaticPool, +) +SessionLocal = sessionmaker(engine, expire_on_commit=False) + + +class DBSessionMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request, call_next): + with SessionLocal() as session: + request.state.session = session + return await call_next(request) + + +def init_db() -> None: + Base.metadata.create_all(engine) + with Session(engine) as session: + alice = User(username="alice") + bob = User(username="bob") + python_tag = Tag(name="Python") + graphql_tag = Tag(name="GraphQL") + sqlalchemy_tag = Tag(name="SQLAlchemy") + session.add_all( + [ + alice, + bob, + python_tag, + graphql_tag, + sqlalchemy_tag, + Post( + title="Hello, GraphQL", + author=alice, + tags=[python_tag, graphql_tag], + is_published=True, + ), + Post( + title="SQLAlchemy 2.0 tips", + author=bob, + tags=[graphql_tag, sqlalchemy_tag], + is_published=True, + ), + Post( + title="Draft notes", + author=alice, + tags=[python_tag], + is_published=False, + ), + ] + ) + session.commit() + + +# --- GraphQL schema -------------------------------------------------------- + +type_defs = """ + type Query { + # auto-resolved, no DataLoader. + posts: [Post!]! + + # hit the DataLoader. + publishedPosts: [Post!]! + } + + type User { + id: ID! + username: String! + posts: [Post!]! + } + + type Post { + id: ID! + title: String! + isPublished: Boolean! + author: User! + tags: [Tag!]! + } + + type Tag { + id: ID! + name: String! + posts: [Post!]! + } +""" + +user_type = SQLAlchemyObjectType("User", User) +post_type = SQLAlchemyObjectType("Post", Post, aliases={"isPublished": "is_published"}) + +# `tag_type` is intentionally NOT passed to `SQLAlchemyQueryType` below, even +# though `Tag` is reachable in the schema via `Post.tags`. Defining +# `tag_type` keeps the per-relationship resolvers on `Tag` (so `Tag.posts` +# still works), but auto_eager_load won't honour any per-type config you +# attach here when traversing into `Tag`. 
+tag_type = SQLAlchemyObjectType("Tag", Tag) + +query_type = SQLAlchemyQueryType([user_type, post_type]) + + +@query_type.field("publishedPosts") +def resolve_published_posts(_, info): + session = info.context["session"] + stmt = select(Post).where(Post.is_published) + return session.execute(stmt).scalars().unique().all() + + +schema = make_executable_schema(type_defs, [query_type, user_type, post_type, tag_type]) + + +# --- Context --------------------------------------------------------------- + +# The `session` key is required for `SQLAlchemyQueryType`. +# `SQLAlchemyDataLoaderExtension` will automatically create the +# `loader_registry` in context for `SQLAlchemyObjectType` + + +async def get_context(request, _data): + return { + "request": request, + "session": request.state.session, + } + + +graphql_app = GraphQL( + schema, + context_value=get_context, + http_handler=GraphQLHTTPHandler(extensions=[SQLAlchemyDataLoaderExtension]), +) + + +@asynccontextmanager +async def lifespan(_app): + init_db() + yield + + +app = Starlette( + debug=True, + lifespan=lifespan, + middleware=[Middleware(DBSessionMiddleware)], +) +app.mount("/", graphql_app) + + +if __name__ == "__main__": + import uvicorn + + uvicorn.run(app, host="0.0.0.0", port=8000) diff --git a/pyproject.toml b/pyproject.toml index 8409ffcb..e575686a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,6 +75,10 @@ test = [ ] asgi-file-uploads = ["python-multipart>=0.0.13"] telemetry = ["opentelemetry-api"] +sqlalchemy = [ + "sqlalchemy>=2.0.0", + "aiodataloader>=0.2.0", +] [project.urls] "Homepage" = "https://ariadnegraphql.org/" @@ -119,6 +123,9 @@ check = [ ## Types environment +[tool.hatch.envs.types] +features = ["types", "sqlalchemy"] + [tool.hatch.envs.types.scripts] check = "ty check" @@ -126,7 +133,7 @@ check = "ty check" ## Test environments [tool.hatch.envs.hatch-test] -features = ["test"] +features = ["test", "sqlalchemy"] extra-args = [] diff --git a/tests/sqlalchemy/__init__.py 
b/tests/sqlalchemy/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/sqlalchemy/conftest.py b/tests/sqlalchemy/conftest.py new file mode 100644 index 00000000..6d4acb80 --- /dev/null +++ b/tests/sqlalchemy/conftest.py @@ -0,0 +1,78 @@ +import pytest +from sqlalchemy import ( + Column, + ForeignKey, + ForeignKeyConstraint, + Integer, + String, + Table, +) +from sqlalchemy.orm import declarative_base, relationship + + +@pytest.fixture +def models(): + """ORM-only definitions mirroring examples/sqlalchemy/01_auto_eager_load.py. + + No engine, no DB - the tests mock `session.execute(...)` directly. Models + only need to expose introspectable RelationshipProperty objects. + """ + base = declarative_base() + + post_tags = Table( + "post_tags", + base.metadata, + Column("post_id", Integer, ForeignKey("posts.id"), primary_key=True), + Column("tag_id", Integer, ForeignKey("tags.id"), primary_key=True), + ) + + class User(base): + __tablename__ = "users" + id = Column(Integer, primary_key=True) + username = Column(String, unique=True) + posts = relationship("Post", back_populates="author") + + class Post(base): + __tablename__ = "posts" + id = Column(Integer, primary_key=True) + title = Column(String) + author_id = Column(Integer, ForeignKey("users.id")) + author = relationship("User", back_populates="posts") + tags = relationship("Tag", secondary=post_tags, back_populates="posts") + + class Tag(base): + __tablename__ = "tags" + id = Column(Integer, primary_key=True) + name = Column(String) + posts = relationship("Post", secondary=post_tags, back_populates="tags") + + return {"User": User, "Post": Post, "Tag": Tag, "post_tags": post_tags} + + +@pytest.fixture +def composite_key_models(): + """Composite-PK/FK models for testing the `tuple_(...).in_(...)` branch.""" + base = declarative_base() + + class Region(base): + __tablename__ = "regions" + country = Column(String, primary_key=True) + code = Column(String, primary_key=True) + name = 
Column(String) + cities = relationship("City", back_populates="region") + + class City(base): + __tablename__ = "cities" + id = Column(Integer, primary_key=True) + name = Column(String) + country = Column(String) + region_code = Column(String) + region = relationship("Region", back_populates="cities") + + __table_args__ = ( + ForeignKeyConstraint( + ["country", "region_code"], ["regions.country", "regions.code"] + ), + ) + + return {"Region": Region, "City": City} diff --git a/tests/sqlalchemy/test_dataloader_extension.py b/tests/sqlalchemy/test_dataloader_extension.py new file mode 100644 index 00000000..31afe6a3 --- /dev/null +++ b/tests/sqlalchemy/test_dataloader_extension.py @@ -0,0 +1,20 @@ +from unittest.mock import Mock + +import pytest + +from ariadne.contrib.sqlalchemy import LoaderRegistry, SQLAlchemyDataLoaderExtension + + +@pytest.fixture +def session(): + return Mock(name="session") + + +def test_request_started_creates_loader_registry(session): + ext = SQLAlchemyDataLoaderExtension() + context = {"session": session} + + ext.request_started(context) + + assert isinstance(context["loader_registry"], LoaderRegistry) + assert context["loader_registry"].session is session diff --git a/tests/sqlalchemy/test_dataloaders.py b/tests/sqlalchemy/test_dataloaders.py new file mode 100644 index 00000000..f8e5063d --- /dev/null +++ b/tests/sqlalchemy/test_dataloaders.py @@ -0,0 +1,386 @@ +from unittest.mock import AsyncMock, Mock + +import pytest +from aiodataloader import DataLoader +from sqlalchemy import inspect as sa_inspect + +from ariadne.contrib.sqlalchemy.dataloaders import ( + LoaderRegistry, + SQLAlchemyDataLoader, +) + + +def get_relation(model, name): + return sa_inspect(model).relationships[name] + + +def make_session(rows): + """Build a sync session whose `execute(stmt).all()` returns `rows`.""" + result = Mock() + result.all.return_value = rows + session = Mock() + session.execute.return_value = result + return session + + +def 
make_async_session(rows): + """Build an async-style session: `execute(...)` returns an awaitable.""" + result = Mock() + result.all.return_value = rows + session = Mock() + session.execute = AsyncMock(return_value=result) + return session + + +# --------------------------------------------------------------------------- +# SQLAlchemyDataLoader.__init__ +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +class TestSQLAlchemyDataLoaderInit: + async def test_one_to_many_resolves_local_and_remote_columns(self, models): + session = Mock(name="session") + relation = get_relation(models["User"], "posts") + loader = SQLAlchemyDataLoader(session, relation) + + assert loader.session is session + assert loader.relation_prop is relation + assert loader.target_model is models["Post"] + assert loader.is_list is True + assert loader.secondary is None + assert loader.local_cols == ["id"] + assert loader.remote_cols == ["author_id"] + + async def test_many_to_one_marks_uselist_false(self, models): + loader = SQLAlchemyDataLoader(Mock(), get_relation(models["Post"], "author")) + + assert loader.is_list is False + assert loader.target_model is models["User"] + assert loader.secondary is None + assert loader.local_cols == ["author_id"] + assert loader.remote_cols == ["id"] + + async def test_many_to_many_uses_secondary_synchronize_pairs(self, models): + loader = SQLAlchemyDataLoader(Mock(), get_relation(models["Post"], "tags")) + + assert loader.is_list is True + assert loader.secondary is models["post_tags"] + # M2M follows synchronize_pairs (parent -> secondary): + # `local_cols` is the parent PK, `remote_cols` is the secondary's + # parent-side FK column - the column we filter and group on. 
+ assert loader.local_cols == ["id"] + assert loader.remote_cols == ["post_id"] + + async def test_many_to_many_reverse_side(self, models): + loader = SQLAlchemyDataLoader(Mock(), get_relation(models["Tag"], "posts")) + + assert loader.is_list is True + assert loader.secondary is models["post_tags"] + assert loader.local_cols == ["id"] + assert loader.remote_cols == ["tag_id"] + + async def test_default_cache_is_enabled(self, models): + loader = SQLAlchemyDataLoader(Mock(), get_relation(models["User"], "posts")) + assert loader.cache is True + + async def test_cache_can_be_disabled(self, models): + loader = SQLAlchemyDataLoader( + Mock(), get_relation(models["User"], "posts"), cache=False + ) + assert loader.cache is False + + async def test_loader_is_an_aiodataloader(self, models): + loader = SQLAlchemyDataLoader(Mock(), get_relation(models["User"], "posts")) + assert isinstance(loader, DataLoader) + + async def test_composite_key_relationship(self, composite_key_models): + loader = SQLAlchemyDataLoader( + Mock(), get_relation(composite_key_models["Region"], "cities") + ) + + assert loader.is_list is True + assert sorted(loader.local_cols) == sorted(["country", "code"]) + assert sorted(loader.remote_cols) == sorted(["country", "region_code"]) + + +# --------------------------------------------------------------------------- +# SQLAlchemyDataLoader.get_query +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +class TestGetQuery: + def _compile(self, stmt): + return str(stmt.compile(compile_kwargs={"literal_binds": True})).lower() + + async def test_simple_relation_uses_in_clause(self, models): + loader = SQLAlchemyDataLoader(Mock(), get_relation(models["User"], "posts")) + + sql = self._compile(loader.get_query([1, 2])) + + assert "from posts" in sql + assert "author_id in" in sql + assert "1" in sql and "2" in sql + + async def test_simple_relation_unwraps_tuple_keys(self, models): + """The dataloader accepts 
both scalars and 1-tuples for single-column + relationships; tuples get flattened into a scalar IN clause.""" + loader = SQLAlchemyDataLoader(Mock(), get_relation(models["User"], "posts")) + + sql = self._compile(loader.get_query([(1,), (2,)])) + + # No tuple comparison - regular `IN (1, 2)`. + assert "in (1, 2)" in sql + + async def test_many_to_one_filters_target_pk(self, models): + loader = SQLAlchemyDataLoader(Mock(), get_relation(models["Post"], "author")) + + sql = self._compile(loader.get_query([1, 2])) + + assert "from users" in sql + assert "users.id in" in sql + + async def test_secondary_query_joins_through_secondary(self, models): + loader = SQLAlchemyDataLoader(Mock(), get_relation(models["Post"], "tags")) + + sql = self._compile(loader.get_query([10, 12])) + + assert "from tags" in sql + assert "join post_tags" in sql + # The dataloader filters on the secondary's parent-side FK column. + assert "post_tags.post_id in" in sql + + async def test_composite_key_query_uses_tuple_in(self, composite_key_models): + loader = SQLAlchemyDataLoader( + Mock(), get_relation(composite_key_models["Region"], "cities") + ) + + sql = self._compile(loader.get_query([("US", "CA"), ("UK", "LD")])) + + # tuple_(...).in_(...) renders as a multi-column IN. 
+ assert "in (" in sql + assert "'us'" in sql and "'ca'" in sql + + async def test_filter_columns_are_appended_as_result_columns(self, models): + """The dataloader appends filter columns so it can group rows by key.""" + loader = SQLAlchemyDataLoader(Mock(), get_relation(models["User"], "posts")) + + stmt = loader.get_query([1]) + assert any("author_id" in str(col) for col in stmt.selected_columns) + + +# --------------------------------------------------------------------------- +# SQLAlchemyDataLoader.batch_load_fn (sync session) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +class TestBatchLoadOneToMany: + async def test_groups_results_by_filter_column(self, models): + # Rows: (post_obj, author_id) - the dataloader groups by author_id. + post_a1 = Mock(name="post-a1") + post_a2 = Mock(name="post-a2") + post_b1 = Mock(name="post-b1") + + session = make_session([(post_a1, 1), (post_a2, 1), (post_b1, 2)]) + + loader = SQLAlchemyDataLoader(session, get_relation(models["User"], "posts")) + result = await loader.batch_load_fn([1, 2, 3]) + + assert result == [[post_a1, post_a2], [post_b1], []] + # The session was driven exactly once - that's the whole point of + # batching. 
+ session.execute.assert_called_once() + + async def test_preserves_input_order(self, models): + post_a = Mock(name="post-a") + post_b = Mock(name="post-b") + session = make_session([(post_a, 1), (post_b, 2)]) + + loader = SQLAlchemyDataLoader(session, get_relation(models["User"], "posts")) + result = await loader.batch_load_fn([2, 3, 1]) + + assert result == [[post_b], [], [post_a]] + + async def test_load_many_dispatches_one_batch(self, models): + """Public DataLoader.load_many funnels through a single batch_load_fn + call - this is the contract Ariadne resolvers rely on.""" + post_a = Mock() + post_b = Mock() + session = make_session([(post_a, 1), (post_b, 2)]) + + loader = SQLAlchemyDataLoader(session, get_relation(models["User"], "posts")) + groups = await loader.load_many([1, 2, 3]) + + assert groups == [[post_a], [post_b], []] + session.execute.assert_called_once() + + +@pytest.mark.asyncio +class TestBatchLoadManyToOne: + async def test_returns_single_object_per_key(self, models): + alice = Mock(name="alice") + bob = Mock(name="bob") + session = make_session([(alice, 1), (bob, 2)]) + + loader = SQLAlchemyDataLoader(session, get_relation(models["Post"], "author")) + result = await loader.batch_load_fn([1, 2]) + + assert result == [alice, bob] + + async def test_returns_none_when_no_match(self, models): + alice = Mock(name="alice") + session = make_session([(alice, 1)]) + + loader = SQLAlchemyDataLoader(session, get_relation(models["Post"], "author")) + # Author id 999 has no row in the result set. + result = await loader.batch_load_fn([1, 999]) + + assert result == [alice, None] + + +@pytest.mark.asyncio +class TestBatchLoadManyToMany: + async def test_groups_through_secondary(self, models): + # Rows for Post.tags grouped by post_tags.post_id (the filter column). 
+ python_tag = Mock(name="python") + graphql_tag = Mock(name="graphql") + session = make_session( + [ + (python_tag, 10), + (graphql_tag, 10), + (python_tag, 11), + (graphql_tag, 12), + ] + ) + + loader = SQLAlchemyDataLoader(session, get_relation(models["Post"], "tags")) + result = await loader.batch_load_fn([10, 11, 12, 13]) + + assert result == [ + [python_tag, graphql_tag], + [python_tag], + [graphql_tag], + [], + ] + + async def test_reverse_side_groups_correctly(self, models): + """Tag.posts goes through the same secondary in the opposite direction + - the dataloader should group by `tag_id`.""" + post_10 = Mock(name="post10") + post_11 = Mock(name="post11") + post_12 = Mock(name="post12") + session = make_session( + [ + (post_10, 1), + (post_11, 1), + (post_10, 2), + (post_12, 2), + ] + ) + + loader = SQLAlchemyDataLoader(session, get_relation(models["Tag"], "posts")) + result = await loader.batch_load_fn([1, 2, 3]) + + assert result == [[post_10, post_11], [post_10, post_12], []] + + +@pytest.mark.asyncio +class TestBatchLoadCompositeKeys: + async def test_groups_by_composite_key_tuple(self, composite_key_models): + sf = Mock(name="sf") + la = Mock(name="la") + buffalo = Mock(name="buf") + london = Mock(name="ldn") + + # Rows: (city, country, region_code) - two filter columns. 
+ session = make_session( + [ + (sf, "US", "CA"), + (la, "US", "CA"), + (buffalo, "US", "NY"), + (london, "UK", "LD"), + ] + ) + + loader = SQLAlchemyDataLoader( + session, get_relation(composite_key_models["Region"], "cities") + ) + result = await loader.batch_load_fn( + [("US", "CA"), ("US", "NY"), ("UK", "LD"), ("FR", "PA")] + ) + + assert result == [[sf, la], [buffalo], [london], []] + + +# --------------------------------------------------------------------------- +# SQLAlchemyDataLoader.batch_load_fn (async session path) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +class TestBatchLoadAsyncSession: + async def test_awaits_async_session_execute(self, models): + """When `session.execute(...)` returns an awaitable (as AsyncSession + does), the dataloader awaits it before consuming `.all()`.""" + post = Mock(name="post") + async_session = make_async_session([(post, 1)]) + + loader = SQLAlchemyDataLoader( + async_session, get_relation(models["User"], "posts") + ) + result = await loader.batch_load_fn([1]) + + async_session.execute.assert_awaited_once() + assert result == [[post]] + + +# --------------------------------------------------------------------------- +# LoaderRegistry +# --------------------------------------------------------------------------- +@pytest.mark.asyncio +class TestLoaderRegistry: + async def test_caches_loader_per_relationship(self, models): + registry = LoaderRegistry(Mock(name="session")) + relation = get_relation(models["User"], "posts") + + a = registry.get_loader(relation) + b = registry.get_loader(relation) + + assert a is b + assert isinstance(a, SQLAlchemyDataLoader) + + async def test_passes_session_to_loader(self, models): + session = Mock(name="session") + registry = LoaderRegistry(session) + + loader = registry.get_loader(get_relation(models["User"], "posts")) + + assert loader.session is session + + async def test_distinct_relationships_get_distinct_loaders(self, 
models): + registry = LoaderRegistry(Mock()) + + users_posts = registry.get_loader(get_relation(models["User"], "posts")) + post_author = registry.get_loader(get_relation(models["Post"], "author")) + + assert users_posts is not post_author + assert users_posts.target_model is models["Post"] + assert post_author.target_model is models["User"] + + async def test_custom_loader_class_is_keyed_separately(self, models): + class CustomLoader(SQLAlchemyDataLoader): + pass + + registry = LoaderRegistry(Mock()) + relation = get_relation(models["User"], "posts") + + default = registry.get_loader(relation) + custom = registry.get_loader(relation, loader_class=CustomLoader) + + assert default is not custom + assert type(default) is SQLAlchemyDataLoader + assert type(custom) is CustomLoader + # Subsequent lookups for the same (relation, class) hit the cache. + assert registry.get_loader(relation, loader_class=CustomLoader) is custom diff --git a/tests/sqlalchemy/test_objects.py b/tests/sqlalchemy/test_objects.py new file mode 100644 index 00000000..913cd90e --- /dev/null +++ b/tests/sqlalchemy/test_objects.py @@ -0,0 +1,340 @@ +from types import SimpleNamespace +from unittest.mock import Mock + +import pytest +from graphql import ( + GraphQLField, + GraphQLInt, + GraphQLList, + GraphQLNonNull, + GraphQLObjectType, + GraphQLSchema, + GraphQLString, +) +from sqlalchemy import inspect as sa_inspect +from sqlalchemy.orm import selectinload + +from ariadne import make_executable_schema +from ariadne.contrib.sqlalchemy import ( + LoaderRegistry, + SQLAlchemyDataLoader, + SQLAlchemyObjectType, + SQLAlchemyQueryType, +) + + +def get_relation(model, name): + return sa_inspect(model).relationships[name] + + +# --------------------------------------------------------------------------- +# Construction defaults & overrides +# --------------------------------------------------------------------------- + + +class TestInit: + def test_defaults(self, models): + ot = SQLAlchemyObjectType("User", 
models["User"]) + + assert ot.name == "User" + assert ot.model is models["User"] + assert ot.aliases == {} + assert ot.strategies == {} + assert ot.max_depth == 3 + + def test_dict_aliases_are_stored_directly(self, models): + aliases = {"my_post_id": "post_id"} + ot = SQLAlchemyObjectType("Post", models["Post"], aliases=aliases) + + assert ot.aliases == aliases + + def test_strategies_and_max_depth_are_stored(self, models): + ot = SQLAlchemyObjectType( + "Post", + models["Post"], + strategies={"author": selectinload, "tags": selectinload}, + max_depth=4, + ) + + assert ot.strategies == {"author": selectinload, "tags": selectinload} + assert ot.max_depth == 4 + + +class TestBindAutoResolvers: + def _build_schema(self, *bindables): + type_defs = """ + type Query { + users: [User!]! + posts: [Post!]! + tags: [Tag!]! + } + + type User { + id: ID! + username: String! + posts: [Post!]! + } + + type Post { + id: ID! + title: String! + author: User! + tags: [Tag!]! + } + + type Tag { + id: ID! + name: String! + posts: [Post!]! + } + """ + return make_executable_schema(type_defs, list(bindables)) + + def test_relationship_fields_get_auto_resolvers(self, models): + user_type = SQLAlchemyObjectType("User", models["User"]) + post_type = SQLAlchemyObjectType("Post", models["Post"]) + tag_type = SQLAlchemyObjectType("Tag", models["Tag"]) + query = SQLAlchemyQueryType([user_type, post_type, tag_type]) + + schema = self._build_schema(query, user_type, post_type, tag_type) + + # Every relationship field has a resolver bound to the GraphQL field. + post = schema.type_map["Post"] + assert post.fields["author"].resolve is not None + assert post.fields["tags"].resolve is not None + + user = schema.type_map["User"] + assert user.fields["posts"].resolve is not None + + # Auto-resolvers also get registered on the bindable's `_resolvers` + # dict so that any subsequent rebind sees them. 
+ assert "author" in post_type._resolvers + assert "tags" in post_type._resolvers + assert "posts" in user_type._resolvers + + def test_existing_resolver_is_not_overwritten_by_auto_resolver(self, models): + post_type = SQLAlchemyObjectType("Post", models["Post"]) + + # User-defined resolver registered before bind_to_schema runs. + @post_type.field("author") + def custom_author(*_): # pragma: no cover - identity check only + return None + + user_type = SQLAlchemyObjectType("User", models["User"]) + tag_type = SQLAlchemyObjectType("Tag", models["Tag"]) + query = SQLAlchemyQueryType([user_type, post_type, tag_type]) + + self._build_schema(query, user_type, post_type, tag_type) + + # The custom resolver wins - auto-binding skips fields already in + # `_resolvers`. + assert post_type._resolvers["author"] is custom_author + + def test_alias_for_field_not_in_schema_is_skipped(self, models): + post_type = SQLAlchemyObjectType( + "Post", models["Post"], aliases={"missing": "title"} + ) + + gql_type = GraphQLObjectType( + "Post", + { + "title": GraphQLField(GraphQLString), + "author": GraphQLField(GraphQLString), + "tags": GraphQLField(GraphQLList(GraphQLString)), + }, + ) + schema = GraphQLSchema(types=[gql_type]) + post_type.bind_to_schema(schema) + + assert "missing" not in post_type._resolvers + + def test_relationship_not_in_schema_is_skipped(self, models): + # Schema exposes Post but omits the `tags` field - auto-binding must + # not register a resolver for a field the schema doesn't define. 
+ gql_type = GraphQLObjectType( + "Post", + { + "title": GraphQLField(GraphQLString), + "author": GraphQLField(GraphQLString), + }, + ) + schema = GraphQLSchema(types=[gql_type]) + post_type = SQLAlchemyObjectType("Post", models["Post"]) + post_type.bind_to_schema(schema) + + assert "tags" not in post_type._resolvers + assert "author" in post_type._resolvers + + def test_missing_type_in_schema_raises(self, models): + schema = GraphQLSchema( + query=GraphQLObjectType( + "Query", {"x": GraphQLField(GraphQLNonNull(GraphQLInt))} + ), + ) + post_type = SQLAlchemyObjectType("Post", models["Post"]) + + with pytest.raises(ValueError, match="Post"): + post_type.bind_to_schema(schema) + + +class TestGetLoaderRegistryFromContext: + def test_returns_value_from_default_key(self): + registry = Mock(name="registry") + assert ( + SQLAlchemyObjectType.get_loader_registry_from_context( + {"loader_registry": registry} + ) + is registry + ) + + def test_missing_key_raises_runtime_error(self): + with pytest.raises(RuntimeError, match="loader_registry"): + SQLAlchemyObjectType.get_loader_registry_from_context({}) + + def test_subclass_can_override_lookup(self, models): + class MyObjectType(SQLAlchemyObjectType): + @staticmethod + def get_loader_registry_from_context(context): + return context["request"].state.loaders + + registry = Mock(name="registry") + context = {"request": SimpleNamespace(state=SimpleNamespace(loaders=registry))} + + assert MyObjectType.get_loader_registry_from_context(context) is registry + + +class TestGetBaseQuery: + def test_returns_select_for_model(self, models): + ot = SQLAlchemyObjectType("Post", models["Post"]) + stmt = ot.get_base_query(info=Mock()) + + # `select(Post)` renders to `SELECT ... FROM posts`. 
+ sql = str(stmt.compile()).lower() + assert "from posts" in sql + + +@pytest.mark.asyncio +class TestRelationResolver: + def _make_resolver(self, post_type, relation): + return post_type._create_relation_resolver(relation) + + async def test_returns_preloaded_value_without_touching_loader(self, models): + """If the relationship is already populated on the ORM instance (e.g. + via `selectinload`), the resolver must read it directly and not + consult the registry - that's the whole point of the auto_eager_load + fast path.""" + post_type = SQLAlchemyObjectType("Post", models["Post"]) + relation = get_relation(models["Post"], "author") + resolver = self._make_resolver(post_type, relation) + + author = SimpleNamespace(username="alice") + # Mimic SQLAlchemy populating the relationship: the attribute is + # present on the instance's __dict__. + post = SimpleNamespace(author=author, author_id=1) + + registry = Mock(spec=LoaderRegistry) + info = SimpleNamespace(context={"loader_registry": registry}) + + result = await resolver(post, info) + + assert result is author + registry.get_loader.assert_not_called() + + async def test_falls_through_to_loader_when_not_preloaded(self, models): + """Manual resolvers return ORM rows with empty relationship state. + The auto-bound resolver must look the relationship up via the + per-request `LoaderRegistry`.""" + post_type = SQLAlchemyObjectType("Post", models["Post"]) + relation = get_relation(models["Post"], "author") + resolver = self._make_resolver(post_type, relation) + + loader = Mock(spec=SQLAlchemyDataLoader) + loader.load = Mock(return_value=_awaitable("alice")) + registry = Mock(spec=LoaderRegistry) + registry.get_loader.return_value = loader + + # __dict__ has the FK column but NOT the relationship attribute. + post = SimpleNamespace(author_id=42) + # Prune `author` from the namespace's auto-attributes so the + # `relation.key in obj.__dict__` check in objects.py returns False. 
+ post.__dict__.pop("author", None) + + info = SimpleNamespace(context={"loader_registry": registry}) + result = await resolver(post, info) + + assert result == "alice" + registry.get_loader.assert_called_once_with(relation) + loader.load.assert_called_once_with(42) + + async def test_unwraps_single_column_key(self, models): + """Single-column FK → loader is called with the scalar value, not a + 1-tuple. (`SQLAlchemyDataLoader` accepts both, but the relation + resolver standardises on the unwrapped form.)""" + post_type = SQLAlchemyObjectType("Post", models["Post"]) + relation = get_relation(models["Post"], "author") + resolver = self._make_resolver(post_type, relation) + + loader = Mock(spec=SQLAlchemyDataLoader) + loader.load = Mock(return_value=_awaitable(None)) + registry = Mock(spec=LoaderRegistry) + registry.get_loader.return_value = loader + + post = SimpleNamespace(author_id=7) + info = SimpleNamespace(context={"loader_registry": registry}) + await resolver(post, info) + + (key,) = loader.load.call_args.args + assert key == 7 # not (7,) + + async def test_passes_composite_key_as_tuple(self, composite_key_models): + """A composite-PK relationship must hand the loader the full tuple.""" + city_type = SQLAlchemyObjectType("City", composite_key_models["City"]) + relation = get_relation(composite_key_models["City"], "region") + resolver = city_type._create_relation_resolver(relation) + + loader = Mock(spec=SQLAlchemyDataLoader) + loader.load = Mock(return_value=_awaitable(None)) + registry = Mock(spec=LoaderRegistry) + registry.get_loader.return_value = loader + + city = SimpleNamespace(country="US", region_code="CA") + info = SimpleNamespace(context={"loader_registry": registry}) + await resolver(city, info) + + (key,) = loader.load.call_args.args + assert isinstance(key, tuple) and len(key) == 2 + assert sorted(key) == sorted(("US", "CA")) + + async def test_uses_overridden_loader_registry_lookup(self, models): + """`get_loader_registry_from_context` is the 
seam users override - + the resolver must go through it, not read `context["loader_registry"]` + directly.""" + + class MyObjectType(SQLAlchemyObjectType): + @staticmethod + def get_loader_registry_from_context(context): + return context["custom_loaders"] + + post_type = MyObjectType("Post", models["Post"]) + relation = get_relation(models["Post"], "author") + resolver = post_type._create_relation_resolver(relation) + + loader = Mock(spec=SQLAlchemyDataLoader) + loader.load = Mock(return_value=_awaitable("alice")) + registry = Mock(spec=LoaderRegistry) + registry.get_loader.return_value = loader + + post = SimpleNamespace(author_id=1) + info = SimpleNamespace(context={"custom_loaders": registry}) + + result = await resolver(post, info) + + assert result == "alice" + registry.get_loader.assert_called_once_with(relation) + + +def _awaitable(value): + async def _coro(): + return value + + return _coro() diff --git a/tests/sqlalchemy/test_query.py b/tests/sqlalchemy/test_query.py new file mode 100644 index 00000000..9a8f5d61 --- /dev/null +++ b/tests/sqlalchemy/test_query.py @@ -0,0 +1,464 @@ +from types import SimpleNamespace +from unittest.mock import AsyncMock, Mock, patch + +import pytest +from graphql import ( + GraphQLField, + GraphQLInt, + GraphQLNonNull, + GraphQLObjectType, + GraphQLSchema, +) +from sqlalchemy.orm import joinedload, selectinload + +from ariadne import make_executable_schema +from ariadne.contrib.sqlalchemy import ( + SQLAlchemyObjectType, + SQLAlchemyQueryType, +) + +TYPE_DEFS = """ + type Query { + users: [User!]! + user(id: ID!): User + posts: [Post!]! + tags: [Tag!]! + ping: String + } + + type User { + id: ID! + username: String! + posts: [Post!]! + } + + type Post { + id: ID! + title: String! + author: User! + tags: [Tag!]! + } + + type Tag { + id: ID! + name: String! + posts: [Post!]! 
+ } +""" + + +def _make_object_types(models): + return ( + SQLAlchemyObjectType("User", models["User"]), + SQLAlchemyObjectType("Post", models["Post"]), + SQLAlchemyObjectType("Tag", models["Tag"]), + ) + + +def _make_sync_session(scalar_first=None, scalar_all=None): + """Build a sync session whose `.execute(stmt)` chain returns the configured + scalars.first() and scalars().unique().all() values.""" + scalars = Mock() + scalars.first.return_value = scalar_first + unique = Mock() + unique.all.return_value = scalar_all if scalar_all is not None else [] + scalars.unique.return_value = unique + + result = Mock() + result.scalars.return_value = scalars + + session = Mock() + session.execute.return_value = result + return session + + +def _make_async_session(scalar_first=None, scalar_all=None): + """Async-style session whose `.execute(...)` returns an awaitable result.""" + scalars = Mock() + scalars.first.return_value = scalar_first + unique = Mock() + unique.all.return_value = scalar_all if scalar_all is not None else [] + scalars.unique.return_value = unique + + result = Mock() + result.scalars.return_value = scalars + + session = Mock() + session.execute = AsyncMock(return_value=result) + return session + + +# --------------------------------------------------------------------------- +# Construction +# --------------------------------------------------------------------------- + + +class TestInit: + def test_inherits_query_name(self, models): + user_type, post_type, tag_type = _make_object_types(models) + query = SQLAlchemyQueryType([user_type, post_type, tag_type]) + + assert query.name == "Query" + + def test_indexes_object_types_by_graphql_name(self, models): + user_type, post_type, tag_type = _make_object_types(models) + query = SQLAlchemyQueryType([user_type, post_type, tag_type]) + + assert query.object_types == { + "User": user_type, + "Post": post_type, + "Tag": tag_type, + } + + def test_indexes_object_types_by_model_class(self, models): + user_type, 
post_type, tag_type = _make_object_types(models) + query = SQLAlchemyQueryType([user_type, post_type, tag_type]) + + assert query._object_types_by_model == { + models["User"]: user_type, + models["Post"]: post_type, + models["Tag"]: tag_type, + } + + def test_accepts_empty_sequence(self): + query = SQLAlchemyQueryType([]) + + assert query.object_types == {} + assert query._object_types_by_model == {} + + +# --------------------------------------------------------------------------- +# get_session_from_context +# --------------------------------------------------------------------------- + + +class TestGetSessionFromContext: + def test_returns_value_from_default_key(self): + session = Mock(name="session") + + assert ( + SQLAlchemyQueryType.get_session_from_context({"session": session}) + is session + ) + + def test_missing_key_raises_runtime_error(self): + with pytest.raises(RuntimeError, match="session"): + SQLAlchemyQueryType.get_session_from_context({}) + + def test_subclass_can_override_lookup(self): + class MyQueryType(SQLAlchemyQueryType): + @staticmethod + def get_session_from_context(context): + return context["request"].state.session + + session = Mock(name="session") + context = {"request": SimpleNamespace(state=SimpleNamespace(session=session))} + + assert MyQueryType.get_session_from_context(context) is session + + +# --------------------------------------------------------------------------- +# bind_to_schema +# --------------------------------------------------------------------------- + + +class TestBindToSchema: + def test_binds_auto_resolvers_to_known_types(self, models): + user_type, post_type, tag_type = _make_object_types(models) + query = SQLAlchemyQueryType([user_type, post_type, tag_type]) + + schema = make_executable_schema( + TYPE_DEFS, [query, user_type, post_type, tag_type] + ) + + query_obj = schema.type_map["Query"] + # Every field whose type is one of our SQLAlchemyObjectTypes must have + # an auto-resolver wired up. 
+ assert query_obj.fields["users"].resolve is not None + assert query_obj.fields["posts"].resolve is not None + assert query_obj.fields["tags"].resolve is not None + assert query_obj.fields["user"].resolve is not None + + # The resolvers also live in the bindable's `_resolvers` map. + assert "users" in query._resolvers + assert "posts" in query._resolvers + assert "tags" in query._resolvers + assert "user" in query._resolvers + + def test_does_not_bind_auto_resolver_for_unknown_type(self, models): + # `ping: String` returns a scalar with no matching SQLAlchemyObjectType. + user_type, post_type, tag_type = _make_object_types(models) + query = SQLAlchemyQueryType([user_type, post_type, tag_type]) + + make_executable_schema(TYPE_DEFS, [query, user_type, post_type, tag_type]) + + assert "ping" not in query._resolvers + + def test_does_not_overwrite_existing_resolver(self, models): + user_type, post_type, tag_type = _make_object_types(models) + query = SQLAlchemyQueryType([user_type, post_type, tag_type]) + + @query.field("users") + def custom_users(*_): # pragma: no cover - identity check only + return [] + + make_executable_schema(TYPE_DEFS, [query, user_type, post_type, tag_type]) + + assert query._resolvers["users"] is custom_users + + def test_unwraps_list_and_nonnull_wrappers(self, models): + """`[User!]!` and `User!` and bare `User` must all resolve to the same + target name `User`. The list wrapper toggles `return_list`.""" + user_type, post_type, tag_type = _make_object_types(models) + query = SQLAlchemyQueryType([user_type, post_type, tag_type]) + + make_executable_schema(TYPE_DEFS, [query, user_type, post_type, tag_type]) + + # both list and scalar fields had resolvers attached + assert "users" in query._resolvers # [User!]! 
+ assert "user" in query._resolvers # User (nullable scalar) + + def test_falls_back_to_super_when_query_type_missing(self, models): + """If the schema has no `Query` GraphQL type, `bind_to_schema` should + defer to the parent class (which raises ValueError). This is the + fallback branch that protects from running the auto-binding loop on + something that isn't a GraphQLObjectType.""" + user_type, _, _ = _make_object_types(models) + query = SQLAlchemyQueryType([user_type]) + + # Schema without a "Query" type at all. + schema = GraphQLSchema( + types=[ + GraphQLObjectType( + "Other", {"x": GraphQLField(GraphQLNonNull(GraphQLInt))} + ) + ] + ) + + with pytest.raises(ValueError, match="Query"): + query.bind_to_schema(schema) + + +# --------------------------------------------------------------------------- +# _create_auto_resolver +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +class TestCreateAutoResolver: + async def test_calls_get_session_from_context(self, models): + user_type, _, _ = _make_object_types(models) + query = SQLAlchemyQueryType([user_type]) + + session = _make_sync_session(scalar_all=[]) + info = SimpleNamespace( + context={"session": session}, + field_nodes=[], + fragments={}, + schema=None, + ) + + with patch( + "ariadne.contrib.sqlalchemy.query.auto_eager_load", + side_effect=lambda stmt, *_a, **_kw: stmt, + ): + resolver = query._create_auto_resolver(user_type, return_list=True) + await resolver(None, info) + + session.execute.assert_called_once() + + async def test_returns_list_when_return_list_true(self, models): + user_type, _, _ = _make_object_types(models) + query = SQLAlchemyQueryType([user_type]) + + rows = [Mock(name="user1"), Mock(name="user2")] + session = _make_sync_session(scalar_all=rows) + info = SimpleNamespace(context={"session": session}) + + with patch( + "ariadne.contrib.sqlalchemy.query.auto_eager_load", + side_effect=lambda stmt, *_a, **_kw: stmt, + ): + resolver = 
query._create_auto_resolver(user_type, return_list=True) + result = await resolver(None, info) + + assert result == rows + session.execute.return_value.scalars.return_value.unique.assert_called_once() + + async def test_returns_first_when_return_list_false(self, models): + user_type, _, _ = _make_object_types(models) + query = SQLAlchemyQueryType([user_type]) + + sentinel = Mock(name="single_user") + session = _make_sync_session(scalar_first=sentinel) + info = SimpleNamespace(context={"session": session}) + + with patch( + "ariadne.contrib.sqlalchemy.query.auto_eager_load", + side_effect=lambda stmt, *_a, **_kw: stmt, + ): + resolver = query._create_auto_resolver(user_type, return_list=False) + result = await resolver(None, info) + + assert result is sentinel + session.execute.return_value.scalars.return_value.first.assert_called_once() + + async def test_awaits_async_session_execute(self, models): + user_type, _, _ = _make_object_types(models) + query = SQLAlchemyQueryType([user_type]) + + rows = [Mock(name="user1")] + session = _make_async_session(scalar_all=rows) + info = SimpleNamespace(context={"session": session}) + + with patch( + "ariadne.contrib.sqlalchemy.query.auto_eager_load", + side_effect=lambda stmt, *_a, **_kw: stmt, + ): + resolver = query._create_auto_resolver(user_type, return_list=True) + result = await resolver(None, info) + + assert result == rows + session.execute.assert_awaited_once() + + async def test_passes_object_type_config_to_auto_eager_load(self, models): + user_type = SQLAlchemyObjectType("User", models["User"]) + post_type = SQLAlchemyObjectType( + "Post", + models["Post"], + aliases={"my_id": "id"}, + strategies={"author": joinedload, "tags": selectinload}, + max_depth=5, + ) + query = SQLAlchemyQueryType([user_type, post_type]) + + session = _make_sync_session(scalar_all=[]) + info = SimpleNamespace(context={"session": session}) + + with patch( + "ariadne.contrib.sqlalchemy.query.auto_eager_load", + side_effect=lambda stmt, *_a, 
**_kw: stmt, + ) as eager_mock: + resolver = query._create_auto_resolver(post_type, return_list=True) + await resolver(None, info) + + # auto_eager_load called with the model + per-type config + the + # query's model→type registry so nested types can be resolved. + _stmt, passed_info, passed_model = eager_mock.call_args.args + kwargs = eager_mock.call_args.kwargs + assert passed_info is info + assert passed_model is models["Post"] + assert kwargs["strategies"] == {"author": joinedload, "tags": selectinload} + assert kwargs["aliases"] == {"my_id": "id"} + assert kwargs["max_depth"] == 5 + assert kwargs["type_registry"] == { + models["User"]: user_type, + models["Post"]: post_type, + } + + async def test_applies_where_filter_for_known_kwargs(self, models): + user_type, _, _ = _make_object_types(models) + query = SQLAlchemyQueryType([user_type]) + + session = _make_sync_session(scalar_first=None) + info = SimpleNamespace(context={"session": session}) + + with patch( + "ariadne.contrib.sqlalchemy.query.auto_eager_load", + side_effect=lambda stmt, *_a, **_kw: stmt, + ): + resolver = query._create_auto_resolver(user_type, return_list=False) + await resolver(None, info, id=42) + + # The constructed statement passed to execute should include a WHERE. + executed_stmt = session.execute.call_args.args[0] + sql = str(executed_stmt.compile()).lower() + assert "where" in sql and "users.id" in sql + + async def test_resolves_kwarg_through_aliases(self, models): + # `my_id` GraphQL arg must be translated to `id` column on the model. 
+ post_type = SQLAlchemyObjectType( + "Post", models["Post"], aliases={"my_id": "id"} + ) + query = SQLAlchemyQueryType([post_type]) + + session = _make_sync_session(scalar_first=None) + info = SimpleNamespace(context={"session": session}) + + with patch( + "ariadne.contrib.sqlalchemy.query.auto_eager_load", + side_effect=lambda stmt, *_a, **_kw: stmt, + ): + resolver = query._create_auto_resolver(post_type, return_list=False) + await resolver(None, info, my_id=7) + + executed_stmt = session.execute.call_args.args[0] + sql = str(executed_stmt.compile()).lower() + assert "posts.id" in sql + + async def test_unknown_kwargs_are_ignored(self, models): + user_type, _, _ = _make_object_types(models) + query = SQLAlchemyQueryType([user_type]) + + session = _make_sync_session(scalar_all=[]) + info = SimpleNamespace(context={"session": session}) + + with patch( + "ariadne.contrib.sqlalchemy.query.auto_eager_load", + side_effect=lambda stmt, *_a, **_kw: stmt, + ): + resolver = query._create_auto_resolver(user_type, return_list=True) + await resolver(None, info, nonexistent_field="ignore-me") + + executed_stmt = session.execute.call_args.args[0] + sql = str(executed_stmt.compile()).lower() + assert "where" not in sql + + async def test_uses_overridden_session_lookup(self, models): + class MyQueryType(SQLAlchemyQueryType): + @staticmethod + def get_session_from_context(context): + return context["custom_session"] + + user_type, _, _ = _make_object_types(models) + query = MyQueryType([user_type]) + + session = _make_sync_session(scalar_all=[]) + info = SimpleNamespace(context={"custom_session": session}) + + with patch( + "ariadne.contrib.sqlalchemy.query.auto_eager_load", + side_effect=lambda stmt, *_a, **_kw: stmt, + ): + resolver = query._create_auto_resolver(user_type, return_list=True) + await resolver(None, info) + + session.execute.assert_called_once() + + async def test_uses_object_types_get_base_query(self, models): + """Subclasses may override `get_base_query` to 
apply default filters - + the auto-resolver must route through it instead of building its own + `select(model)`.""" + + class FilteredPost(SQLAlchemyObjectType): + def get_base_query(self, info, **kwargs): + from sqlalchemy import select + + return select(self.model).where(self.model.title == "fixed") + + post_type = FilteredPost("Post", models["Post"]) + query = SQLAlchemyQueryType([post_type]) + + session = _make_sync_session(scalar_all=[]) + info = SimpleNamespace(context={"session": session}) + + with patch( + "ariadne.contrib.sqlalchemy.query.auto_eager_load", + side_effect=lambda stmt, *_a, **_kw: stmt, + ): + resolver = query._create_auto_resolver(post_type, return_list=True) + await resolver(None, info) + + executed_stmt = session.execute.call_args.args[0] + sql = str(executed_stmt.compile()).lower() + assert "posts.title" in sql and "where" in sql diff --git a/tests/sqlalchemy/test_utils.py b/tests/sqlalchemy/test_utils.py new file mode 100644 index 00000000..77213477 --- /dev/null +++ b/tests/sqlalchemy/test_utils.py @@ -0,0 +1,404 @@ +from types import SimpleNamespace +from unittest.mock import Mock, patch + +import pytest +from graphql import GraphQLError, parse +from sqlalchemy.orm import ( + class_mapper, + joinedload, + lazyload, + selectinload, + subqueryload, +) + +from ariadne.contrib.sqlalchemy import SQLAlchemyObjectType +from ariadne.contrib.sqlalchemy.utils import ( + _build_options, + _resolve_load_option, + auto_eager_load, +) + + +def _info_for(query_string: str, root_field: str) -> SimpleNamespace: + """Mimic the `info` object passed to a resolver for `root_field`. + + Picks the matching FieldNodes out of the parsed operation's selection set, + matching what graphql-core hands the resolver at runtime. 
+ """ + document = parse(query_string) + operation = document.definitions[0] + field_nodes = [ + node + for node in operation.selection_set.selections + if node.name.value == root_field + ] + return SimpleNamespace(field_nodes=field_nodes) + + +def _selections(query_string: str, root_field: str): + """Return the inner selection set's selections for `root_field`.""" + info = _info_for(query_string, root_field) + return info.field_nodes[0].selection_set.selections + + +# --------------------------------------------------------------------------- +# _resolve_load_option +# --------------------------------------------------------------------------- + + +class TestResolveLoadOption: + def test_defaults_to_selectinload_for_collections(self, models): + mapper = class_mapper(models["Post"]) + rel = mapper.relationships["tags"] + + opt = _resolve_load_option(mapper, "tags", None, rel, load_path=None) + + assert "Post.tags" in str(opt.path) + # selectinload emits a separate IN query; joinedload would have rolled + # the relationship into the parent path in a single statement. We can + # tell them apart by patching and re-running. 
+ with patch( + "ariadne.contrib.sqlalchemy.utils.selectinload", + wraps=selectinload, + ) as sel: + _resolve_load_option(mapper, "tags", None, rel, load_path=None) + sel.assert_called_once() + + def test_defaults_to_joinedload_for_scalars(self, models): + mapper = class_mapper(models["Post"]) + rel = mapper.relationships["author"] + + with patch( + "ariadne.contrib.sqlalchemy.utils.joinedload", + wraps=joinedload, + ) as joined: + _resolve_load_option(mapper, "author", None, rel, load_path=None) + joined.assert_called_once() + + def test_uses_explicit_strategy(self, models): + mapper = class_mapper(models["Post"]) + rel = mapper.relationships["tags"] + strategy = Mock(side_effect=lambda attr: selectinload(attr)) + strategy.__name__ = "selectinload" + + _resolve_load_option(mapper, "tags", strategy, rel, load_path=None) + + strategy.assert_called_once() + # The single positional arg should be the InstrumentedAttribute + passed = strategy.call_args.args[0] + assert passed is models["Post"].tags + + def test_chains_onto_load_path(self, models): + mapper = class_mapper(models["Post"]) + rel_tags = mapper.relationships["tags"] + + root = _resolve_load_option(mapper, "tags", selectinload, rel_tags, None) + + tag_mapper = class_mapper(models["Tag"]) + rel_posts = tag_mapper.relationships["posts"] + nested = _resolve_load_option( + tag_mapper, "posts", selectinload, rel_posts, load_path=root + ) + + # Nested option's path includes both legs of the relationship chain. 
+ path_str = str(nested.path) + assert "Post.tags" in path_str + assert "Tag.posts" in path_str + + +# --------------------------------------------------------------------------- +# auto_eager_load +# --------------------------------------------------------------------------- + + +class TestAutoEagerLoad: + def test_returns_query_unchanged_when_no_field_nodes(self, models): + query = Mock(name="query") + info = SimpleNamespace(field_nodes=[]) + + result = auto_eager_load(query, info, models["Post"]) + + assert result is query + query.options.assert_not_called() + + def test_returns_query_unchanged_when_root_field_has_no_selections(self, models): + """A scalar root field has no selection set, so there is nothing to + eager-load - the query must come back untouched.""" + query = Mock(name="query") + # `ping` would be a scalar root field with no inner selection set. + document = parse("query Q { ping }") + operation = document.definitions[0] + info = SimpleNamespace(field_nodes=list(operation.selection_set.selections)) + + result = auto_eager_load(query, info, models["Post"]) + + assert result is query + query.options.assert_not_called() + + def test_passes_options_to_query(self, models): + query = Mock(name="query") + info = _info_for("query Q { posts { title } }", "posts") + + auto_eager_load(query, info, models["Post"]) + + query.options.assert_called_once() + # At minimum, scalar `title` should produce one load_only option. + assert len(query.options.call_args.args) >= 1 + + def test_emits_load_only_for_scalar_fields(self, models): + query = Mock(name="query") + info = _info_for("query Q { posts { id title } }", "posts") + + with patch( + "ariadne.contrib.sqlalchemy.utils.load_only", wraps=lambda *a: a + ) as lo: + auto_eager_load(query, info, models["Post"]) + + lo.assert_called_once() + loaded_attrs = {attr.key for attr in lo.call_args.args} + # Includes the requested scalars plus the FK column needed by any + # relationship resolver fall-back. 
+        assert {"id", "title"}.issubset(loaded_attrs)
+        assert "author_id" in loaded_attrs  # FK backing the `author` relationship
+
+    def test_includes_fk_columns_in_load_only(self, models):
+        """Even if the GraphQL selection only asks for a single scalar, the
+        FK columns of every relationship on that mapper must be loaded so the
+        DataLoader fallback can fetch related rows by FK."""
+        query = Mock(name="query")
+        info = _info_for("query Q { posts { title } }", "posts")
+
+        with patch(
+            "ariadne.contrib.sqlalchemy.utils.load_only", wraps=lambda *a: a
+        ) as lo:  # stub returns the attrs tuple so they can be inspected below
+            auto_eager_load(query, info, models["Post"])
+
+        loaded = {attr.key for attr in lo.call_args.args}
+        assert "author_id" in loaded
+
+    def test_no_load_only_when_only_relationships_selected(self, models):
+        """If the selection contains only relationships (no scalar columns),
+        nothing is loaded via `load_only` - we just attach the relationship
+        loaders. This means FKs are not pre-loaded either."""
+        query = Mock(name="query")
+        info = _info_for("query Q { posts { tags { id } } }", "posts")
+
+        with patch(
+            "ariadne.contrib.sqlalchemy.utils.load_only", wraps=lambda *a: a
+        ) as lo:
+            auto_eager_load(query, info, models["Post"])
+
+        # Only the inner Tag's `id` triggers load_only - never the outer Post.
+        for call in lo.call_args_list:
+            keys = {attr.key for attr in call.args}
+            assert "title" not in keys  # no outer Post scalar was requested
+
+    def test_default_strategy_for_collection_relationship(self, models):
+        query = Mock(name="query")
+        info = _info_for("query Q { posts { tags { id } } }", "posts")
+
+        with patch(
+            "ariadne.contrib.sqlalchemy.utils.selectinload",
+            wraps=selectinload,
+        ) as sel:
+            auto_eager_load(query, info, models["Post"])
+
+        # `tags` is a collection -> selectinload by default
+        sel.assert_called()
+
+    def test_default_strategy_for_scalar_relationship(self, models):
+        query = Mock(name="query")
+        info = _info_for("query Q { posts { author { id } } }", "posts")
+
+        with patch(
+            "ariadne.contrib.sqlalchemy.utils.joinedload",
+            wraps=joinedload,
+        ) as joined:
+            auto_eager_load(query, info, models["Post"])
+
+        # `author` is a many-to-one scalar -> joinedload by default
+        joined.assert_called()
+
+    def test_explicit_strategy_overrides_default(self, models):
+        query = Mock(name="query")
+        info = _info_for("query Q { posts { tags { id } } }", "posts")
+        strategy = Mock(side_effect=subqueryload)  # mock that delegates to a real loader
+        strategy.__name__ = "subqueryload"
+
+        auto_eager_load(query, info, models["Post"], strategies={"tags": strategy})
+
+        strategy.assert_called_once()
+
+    def test_aliases_translate_graphql_field_to_db_attr(self, models):
+        query = Mock(name="query")
+        # GraphQL exposes `myTitle`; map it to the `title` column on Post.
+        info = _info_for("query Q { posts { myTitle } }", "posts")
+
+        with patch(
+            "ariadne.contrib.sqlalchemy.utils.load_only", wraps=lambda *a: a
+        ) as lo:
+            auto_eager_load(query, info, models["Post"], aliases={"myTitle": "title"})
+
+        loaded = {attr.key for attr in lo.call_args.args}
+        assert "title" in loaded  # alias resolved to the mapped column
+
+    def test_recurses_into_nested_selections(self, models):
+        query = Mock(name="query")
+        info = _info_for("query Q { posts { tags { name posts { id } } } }", "posts")
+
+        auto_eager_load(query, info, models["Post"])
+
+        opts = query.options.call_args.args
+        joined = " | ".join(str(o.path) for o in opts)
+        assert "Post.tags" in joined
+        assert "Tag.posts" in joined  # nested level was visited too
+
+    def test_aliases_only_apply_at_root_without_registry(self, models):
+        """`auto_eager_load`'s `aliases` argument applies to the root model
+        only. Nested types fall back to no aliases unless a `type_registry`
+        provides per-type config."""
+        query = Mock(name="query")
+        # Root selection uses `myTitle` (alias) and a relationship; the nested
+        # tag selection uses unaliased `name`.
+        info = _info_for("query Q { posts { myTitle tags { name } } }", "posts")
+
+        auto_eager_load(query, info, models["Post"], aliases={"myTitle": "title"})
+
+        query.options.assert_called_once()  # completed without raising
+
+    def test_uses_type_registry_for_nested_type_aliases(self, models):
+        """A type_registry entry for the nested type supplies that type's
+        aliases - the parent's aliases do not leak into the recursion."""
+        query = Mock(name="query")
+        info = _info_for("query Q { posts { tags { my_name } } }", "posts")
+
+        post_ot = SQLAlchemyObjectType("Post", models["Post"])
+        tag_ot = SQLAlchemyObjectType("Tag", models["Tag"], aliases={"my_name": "name"})
+        registry = {models["Post"]: post_ot, models["Tag"]: tag_ot}
+
+        from ariadne.contrib.sqlalchemy.utils import _build_options
+
+        with patch(
+            "ariadne.contrib.sqlalchemy.utils._build_options",
+            wraps=_build_options,
+        ) as build_mock:
+            auto_eager_load(query, info, models["Post"], type_registry=registry)
+
+        # Find the recursive call entered for the Tag mapper.
+        tag_calls = [
+            c for c in build_mock.call_args_list if c.args[0].class_ is models["Tag"]
+        ]
+        assert tag_calls, "expected a recursive call for the Tag mapper"
+        # Aliases dict at the Tag level is sourced from the Tag's own config.
+        passed_aliases = tag_calls[0].args[3]
+        assert passed_aliases == {"my_name": "name"}
+
+    def test_uses_child_strategy_from_registry(self, models):
+        """A nested type's `strategies` dict is what controls how its own
+        relationships load. The strategy's `__name__` is what matters at
+        nested levels - the option is chained via `load_path.(attr)`."""
+        query = Mock(name="query")
+        info = _info_for("query Q { posts { tags { posts { id } } } }", "posts")
+
+        post_ot = SQLAlchemyObjectType("Post", models["Post"])
+        # Use a real loader so chaining via `load_path.subqueryload` works.
+        tag_ot = SQLAlchemyObjectType(
+            "Tag", models["Tag"], strategies={"posts": subqueryload}
+        )
+        registry = {models["Post"]: post_ot, models["Tag"]: tag_ot}
+
+        with patch(
+            "ariadne.contrib.sqlalchemy.utils._resolve_load_option",
+            wraps=_resolve_load_option,
+        ) as resolve_mock:
+            auto_eager_load(query, info, models["Post"], type_registry=registry)
+
+        # The Tag-level resolution must receive the Tag's strategy (`subqueryload`),
+        # rather than falling back to the default `selectinload` for collections.
+        tag_calls = [c for c in resolve_mock.call_args_list if c.args[1] == "posts"]
+        assert tag_calls, "expected a resolve call for Tag.posts"
+        assert tag_calls[0].args[2] is subqueryload
+
+    def test_max_depth_from_registry_raises_graphql_error(self, models):
+        """A type's `max_depth` is sourced from its `SQLAlchemyObjectType` in
+        the registry. Re-entering the same type beyond that count raises."""
+        query = Mock(name="query")
+        info = _info_for(
+            # Post -> tags -> posts -> tags : Post is entered twice
+            "query Q { posts { tags { posts { tags { id } } } } }",
+            "posts",
+        )
+
+        post_ot = SQLAlchemyObjectType("Post", models["Post"], max_depth=1)
+        tag_ot = SQLAlchemyObjectType("Tag", models["Tag"], max_depth=1)
+        registry = {models["Post"]: post_ot, models["Tag"]: tag_ot}
+
+        with pytest.raises(GraphQLError, match="max_depth"):
+            auto_eager_load(query, info, models["Post"], type_registry=registry)
+
+    def test_default_max_depth_three_allows_three_levels(self, models):
+        """With the default max_depth=3, three levels of the same type should
+        load without raising."""
+        query = Mock(name="query")
+        info = _info_for(
+            "query Q { posts { tags { posts { tags { id } } } } }",
+            "posts",
+        )
+
+        # No registry: every type defaults to max_depth=3.
+        auto_eager_load(query, info, models["Post"])
+
+        query.options.assert_called_once()  # options attached, no GraphQLError raised
+
+    def test_max_depth_error_includes_type_name_and_depth(self, models):
+        query = Mock(name="query")
+        info = _info_for("query Q { posts { tags { posts { id } } } }", "posts")
+
+        post_ot = SQLAlchemyObjectType("Post", models["Post"], max_depth=1)
+        registry = {models["Post"]: post_ot}  # only Post is depth-constrained
+
+        with pytest.raises(GraphQLError) as exc_info:
+            auto_eager_load(query, info, models["Post"], type_registry=registry)
+
+        message = str(exc_info.value)
+        assert "Post" in message
+        assert "max_depth=1" in message  # message names the configured limit
+
+    def test_unknown_graphql_field_is_ignored(self, models):
+        """Selecting a GraphQL field that doesn't exist on the SQLAlchemy
+        model (no column, no relationship, no alias) just gets skipped."""
+        query = Mock(name="query")
+        info = _info_for("query Q { posts { id ghostField } }", "posts")
+
+        auto_eager_load(query, info, models["Post"])
+
+        query.options.assert_called_once()
+
+
+# ---------------------------------------------------------------------------
+# _build_options (exercised directly for finer-grained branch coverage)
+# ---------------------------------------------------------------------------
+
+
+class TestBuildOptions:
+    def test_returns_empty_list_for_empty_selection(self, models):
+        mapper = class_mapper(models["Post"])
+        opts = _build_options(mapper, [], {}, {}, {models["Post"]: 1})
+        assert opts == []
+
+    def test_supports_lazyload_strategy_via_dunder_name(self, models):
+        """The chaining branch (`getattr(strategy, '__name__')`) requires the
+        strategy to expose `__name__`. The shipped SQLAlchemy strategies all
+        do - sanity-check with `lazyload`."""
+        mapper = class_mapper(models["Post"])
+        selections = _selections("query Q { posts { tags { posts { id } } } }", "posts")
+
+        opts = _build_options(
+            mapper,
+            selections,
+            {"tags": lazyload, "posts": lazyload},
+            {},
+            {models["Post"]: 1},
+        )
+
+        assert any("Post.tags" in str(o.path) for o in opts)