Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion AGENTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,19 @@ hatch test benchmark --benchmark-storage=file://benchmark/results
just coverage
```

### Running Examples

Use `uv` to run examples from the `examples/` directory. Refer to the docstrings within each example file for specific commands.
Every example should include an example command of running a particular example with uvicorn.
```bash
# Run a basic query example
uv run examples/basic_query_example.py

# Run an ASGI example with uvicorn
uv run --with "uvicorn[standard]" --with ariadne \
uvicorn examples.basic_query_example:app --reload
```

## Code Style Requirements

- **Python 3.10+** with type hints throughout
Expand Down Expand Up @@ -96,4 +109,6 @@ Follow [Conventional Commits](https://www.conventionalcommits.org/):
2. Ensure test coverage meets the 90% minimum requirement
3. Format code with `just fmt`
4. Verify type hints with `just types`
5. Write a clear commit message following the conventional commits format
5. Ensure the documentation is up-to-date
6. Write a clear commit message following the conventional commits format

6 changes: 4 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ All notable unreleased changes to this project will be documented in this file.

For released versions, see the [Releases](https://github.com/mirumee/ariadne/releases) page.

## Unreleased

## 1.1.0a3 (2026-05-06)

=======
### ✨ New Features
- Automated SQLAlchemy integration
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ Documentation is available [here](https://ariadnegraphql.org).
- Loading schema from `.graphql`, `.gql`, and `.graphqls` files.
- ASGI and WSGI support, with integrations for Django, FastAPI, Flask, and Starlette.
- Opt-in automatic resolvers mapping between `camelCase123` and `snake_case_123`.
- Automated integration with **SQLAlchemy 2.0** for zero-boilerplate resolvers and N+1 prevention.
- [OpenTelemetry](https://opentelemetry.io/) extension for API monitoring.
- Built-in [GraphiQL](https://github.com/graphql/graphiql) explorer for development and testing.
- GraphQL syntax validation via `gql()` helper function.
Expand Down
22 changes: 22 additions & 0 deletions ariadne/contrib/sqlalchemy/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
try:
from .dataloaders import LoaderRegistry, SQLAlchemyDataLoader
from .extension import SQLAlchemyDataLoaderExtension
from .objects import SQLAlchemyObjectType
from .query import SQLAlchemyQueryType
from .types import LoadStrategy
from .utils import auto_eager_load
except ImportError as ex:
raise ImportError(
"SQLAlchemy integration requires the 'sqlalchemy' and 'aiodataloader' "
"packages. Install them using 'pip install \"ariadne[sqlalchemy]\"'."
) from ex

__all__ = [
"SQLAlchemyDataLoaderExtension",
"LoadStrategy",
"LoaderRegistry",
"SQLAlchemyObjectType",
"SQLAlchemyQueryType",
"SQLAlchemyDataLoader",
"auto_eager_load",
]
124 changes: 124 additions & 0 deletions ariadne/contrib/sqlalchemy/dataloaders.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
import inspect
import logging
from collections import defaultdict
from typing import Any

from aiodataloader import DataLoader
from sqlalchemy import select, tuple_
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import RelationshipProperty, Session

logger = logging.getLogger(__name__)


class SQLAlchemyDataLoader(DataLoader):
"""
DataLoader for SQLAlchemy relationships supporting:
- Composite Keys
- Many-to-Many (secondary tables)
- Result grouping via SQL columns (optimized)
"""

def __init__(
self,
session: Session | AsyncSession,
relation_prop: RelationshipProperty,
cache: bool = True,
):
super().__init__(cache=cache)
self.session = session
self.relation_prop = relation_prop
self.target_model = relation_prop.mapper.class_
self.is_list = relation_prop.uselist

# Identify local and remote columns (handles composite keys)
if relation_prop.secondary is not None:
self.local_cols = [
lp.key
for lp, rp in relation_prop.synchronize_pairs
if lp.key is not None
]
self.remote_cols = [
rp.key
for lp, rp in relation_prop.synchronize_pairs
if rp.key is not None
]
else:
self.local_cols = [
c.key for c in relation_prop.local_columns if c.key is not None
]
self.remote_cols = [
c.key for c in relation_prop.remote_side if c.key is not None
]

self.secondary = relation_prop.secondary

def get_query(self, keys: list[Any]):
"""Builds query. Handles composite IN clause and M2M joins."""
target_model = self.target_model
stmt = select(target_model)

if self.secondary is not None:
stmt = stmt.join(self.secondary)
filter_cols = [self.secondary.c[k] for k in self.remote_cols]
else:
filter_cols = [getattr(target_model, k) for k in self.remote_cols]

# Add the filtering columns to the result to allow grouping
stmt = stmt.add_columns(*filter_cols)

if len(filter_cols) > 1:
stmt = stmt.where(tuple_(*filter_cols).in_(keys))
else:
# Flatten keys if they are single-element tuples
flat_keys = [k[0] if isinstance(k, (list, tuple)) else k for k in keys]
stmt = stmt.where(filter_cols[0].in_(flat_keys))

return stmt

async def batch_load_fn(self, keys: list[Any]) -> list[Any]:
logger.debug(
"SQLAlchemyRelationLoader: Fetching %s for %d parents",
self.target_model.__name__,
len(keys),
)
stmt = self.get_query(keys)

result = self.session.execute(stmt)
if inspect.isawaitable(result):
result = await result

rows = result.all() # type: ignore

num_filter_cols = len(self.remote_cols)
grouped = defaultdict(list)

for row in rows:
item = row[0]
# The filter columns are appended after the model instance
key_parts = row[1 : 1 + num_filter_cols]
key = tuple(key_parts) if num_filter_cols > 1 else key_parts[0]
grouped[key].append(item)

return [
grouped[k] if self.is_list else (grouped[k][0] if grouped[k] else None)
for k in keys
]


class LoaderRegistry:
def __init__(self, session: Session | AsyncSession):
self.session = session
self._loaders: dict[
tuple[RelationshipProperty, type[DataLoader]], DataLoader
] = {}

def get_loader(
self,
relation_prop: RelationshipProperty,
loader_class: type[SQLAlchemyDataLoader] = SQLAlchemyDataLoader,
) -> DataLoader:
key = (relation_prop, loader_class)
if key not in self._loaders:
self._loaders[key] = loader_class(self.session, relation_prop)
return self._loaders[key]
26 changes: 26 additions & 0 deletions ariadne/contrib/sqlalchemy/extension.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from typing import Any

from ...types import Extension
from .dataloaders import LoaderRegistry


class SQLAlchemyDataLoaderExtension(Extension):
"""Ariadne extension that creates a per-request `LoaderRegistry`.

Wires the SQLAlchemy DataLoader fallback path automatically: at the start
of each GraphQL request, reads the session from `context[session_key]`
and writes a fresh `LoaderRegistry(session)` to `context[registry_key]`.

"""

def __init__(
self,
*,
session_key: str = "session",
registry_key: str = "loader_registry",
):
self.session_key = session_key
self.registry_key = registry_key

def request_started(self, context: Any) -> None:
context[self.registry_key] = LoaderRegistry(context[self.session_key])
124 changes: 124 additions & 0 deletions ariadne/contrib/sqlalchemy/objects.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
from __future__ import annotations

from collections.abc import Callable
from typing import Any, cast

from graphql import GraphQLObjectType, GraphQLSchema
from sqlalchemy import select
from sqlalchemy.orm import DeclarativeBase, RelationshipProperty, class_mapper

from ...objects import ObjectType
from .dataloaders import LoaderRegistry
from .types import LoadStrategy


class SQLAlchemyObjectType(ObjectType):
"""
ObjectType specialized for SQLAlchemy models.
Automatically binds resolvers for relationships using DataLoaders.
"""

model: type[DeclarativeBase]
aliases: dict[str, str]
strategies: dict[str, LoadStrategy]
max_depth: int
_registry_key: str

def __init__(
self,
name: str,
model: type[DeclarativeBase],
*,
aliases: dict[str, str] | Callable[[], dict[str, str]] | None = None,
strategies: dict[str, LoadStrategy] | None = None,
max_depth: int = 3,
):
super().__init__(name)
self.model = model
self.aliases = aliases() if callable(aliases) else (aliases or {}) # ty: ignore[call-top-callable]
self.strategies = strategies or {}
self.max_depth = max_depth

def bind_to_schema(self, schema: GraphQLSchema) -> None:
"""Binds this `SQLAlchemyObjectType` to the GraphQL schema.

Auto-generates resolvers for the model's relationships and aliased
columns, then delegates to `ObjectType.bind_to_schema` to wire them
(along with any explicitly-set resolvers) onto the schema's fields.

The auto-resolvers must be registered before calling `super()` so
they are included when the parent iterates `self._resolvers` to
populate the GraphQL type's field `resolve` attributes.
"""
graphql_type = schema.type_map.get(self.name)
self.validate_graphql_type(graphql_type)
self._bind_auto_resolvers(cast(GraphQLObjectType, graphql_type))
super().bind_to_schema(schema)

def get_base_query(self, info: Any, **kwargs: Any):
"""
Returns the base SQLAlchemy select statement for root queries.
Can be overridden to apply default filters.
"""
return select(self.model)

def _bind_auto_resolvers(self, graphql_type: GraphQLObjectType) -> None:
schema_fields = graphql_type.fields
mapper = class_mapper(self.model)

for gql_field, db_attr in self.aliases.items():
if gql_field not in schema_fields:
continue
if callable(db_attr):
self.set_field(gql_field, db_attr)
else:
self.set_field(
gql_field, lambda obj, *_, _attr=db_attr: getattr(obj, _attr)
)

for relation in mapper.relationships:
if relation.key not in schema_fields:
continue
if relation.key in self._resolvers:
continue
self.set_field(relation.key, self._create_relation_resolver(relation))

@staticmethod
def get_loader_registry_from_context(context: Any) -> LoaderRegistry:
"""Get the `LoaderRegistry` from the GraphQL context.

Override this method to customize how the registry is retrieved.
"""
try:
return context["loader_registry"]
except KeyError:
raise RuntimeError(
"LoaderRegistry not found in context under key 'loader_registry'"
)

def _create_relation_resolver(self, relation: RelationshipProperty):
async def resolve(obj: Any, info: Any, **kwargs: Any):
# If the attribute is already loaded (e.g. via joinedload/selectinload),
# return it
if relation.key in obj.__dict__:
return getattr(obj, relation.key)

loader_registry = self.get_loader_registry_from_context(info.context)

# Identify which column(s) on the current object connect it to the
# target table. For a One-to-Many, this is usually a Foreign Key.
local_relation_columns = [
c.key for c in relation.local_columns if c.key is not None
]

# Extract the actual database values from this specific object instance.
join_values = tuple(getattr(obj, col) for col in local_relation_columns)

# If it's a standard single-column relationship, unwrap the tuple to just
# the ID. If it's a composite key, keep the tuple.
lookup_key = join_values[0] if len(join_values) == 1 else join_values

loader = loader_registry.get_loader(relation)
return await loader.load(lookup_key)

return resolve
Loading
Loading