Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions litestar/di.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from inspect import isasyncgenfunction, isclass, isgeneratorfunction
from inspect import isasyncgenfunction, isclass, isfunction, isgeneratorfunction
from typing import TYPE_CHECKING, Any

from litestar._signature import SignatureModel
Expand Down Expand Up @@ -58,12 +58,15 @@ def __init__(
raise ImproperlyConfiguredException("Provider dependency must be a callable value")

is_class_dependency = isclass(dependency)
self.has_sync_generator_dependency = isgeneratorfunction(
dependency if not is_class_dependency else dependency.__call__ # type: ignore[operator]
)
self.has_async_generator_dependency = isasyncgenfunction(
dependency if not is_class_dependency else dependency.__call__ # type: ignore[operator]
)
is_function = isfunction(dependency)
is_callable_instance = not is_class_dependency and not is_function
if is_class_dependency or is_callable_instance:
check_target = dependency.__call__ # type: ignore[operator]
else:
check_target = dependency
self.has_sync_generator_dependency = isgeneratorfunction(check_target)
self.has_async_generator_dependency = isasyncgenfunction(check_target)

has_generator_dependency = self.has_sync_generator_dependency or self.has_async_generator_dependency

if has_generator_dependency and use_cache:
Expand Down
135 changes: 135 additions & 0 deletions tests/unit/test_di.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,36 @@ async def async_generator_func() -> AsyncGenerator[float, None]:
yield 0.1


class SyncGeneratorCallable:
def __init__(self) -> None:
self.call_count = 0
self.cleanup_count = 0

def __call__(self) -> Generator[int, None]:
self.call_count += 1

try:
yield self.call_count
finally:
# Cleanup: remove the session
self.cleanup_count += 1


class AsyncGeneratorCallable:
def __init__(self) -> None:
self.call_count = 0
self.cleanup_count = 0

async def __call__(self) -> AsyncGenerator[int, None]:
self.call_count += 1

try:
yield self.call_count
finally:
# Cleanup: remove the session
self.cleanup_count += 1


async def async_callable(val: str = "three-one") -> str:
return val

Expand Down Expand Up @@ -145,12 +175,39 @@ def func() -> Generator[int, None, None]:
(async_callable, False),
(generator_func, True),
(async_generator_func, True),
(SyncGeneratorCallable(), True),
],
)
def test_dependency_has_async_callable(dep: Any, exp: bool) -> None:
assert Provide(dep).has_sync_callable is exp


@pytest.mark.parametrize(
("dep", "exp"),
[
(generator_func, True),
(async_generator_func, False),
(SyncGeneratorCallable(), True),
(AsyncGeneratorCallable(), False),
],
)
def test_dependency_has_sync_generator(dep: Any, exp: bool) -> None:
assert Provide(dep).has_sync_generator_dependency is exp


@pytest.mark.parametrize(
("dep", "exp"),
[
(generator_func, False),
(async_generator_func, True),
(SyncGeneratorCallable(), False),
(AsyncGeneratorCallable(), True),
],
)
def test_dependency_has_async_generator(dep: Any, exp: bool) -> None:
assert Provide(dep).has_async_generator_dependency is exp


def test_raises_when_dependency_is_not_callable() -> None:
with pytest.raises(ImproperlyConfiguredException):
Provide(123) # type: ignore[arg-type]
Expand Down Expand Up @@ -179,3 +236,81 @@ async def foo() -> None:

with pytest.raises(ValueError):
provide.parsed_fn_signature


@pytest.mark.asyncio
async def test_stateful_sync_generator_with_cleanup() -> None:
"""Verify that sync stateful callable instances maintain state and execute cleanup."""
factory = SyncGeneratorCallable()
provide = Provide(factory, sync_to_thread=None)

# Sanity check for detection
assert provide.has_sync_generator_dependency is True
assert provide.has_async_generator_dependency is False

# First call
gen1 = await provide()
assert isinstance(gen1, Generator)

session1 = next(gen1)
assert session1 == 1
assert factory.call_count == 1
assert factory.cleanup_count == 0

# Second call (state should be maintained)
gen2 = await provide()
assert isinstance(gen2, Generator)

session2 = next(gen2)
assert session2 == 2
assert factory.call_count == 2
assert factory.cleanup_count == 0

# Cleanup first generator
with pytest.raises(StopIteration):
next(gen1)
assert factory.cleanup_count == 1

# Cleanup second generator
with pytest.raises(StopIteration):
next(gen2)
assert factory.cleanup_count == 2


@pytest.mark.asyncio
async def test_stateful_async_generator_with_cleanup() -> None:
"""Verify that async stateful callable instances maintain state and execute cleanup."""
factory = AsyncGeneratorCallable()
provide = Provide(factory, sync_to_thread=None)

# Sanity check for detection
assert provide.has_sync_generator_dependency is False
assert provide.has_async_generator_dependency is True

# First call
gen1 = await provide()
assert isinstance(gen1, AsyncGenerator)

session1 = await gen1.__anext__()
assert session1 == 1
assert factory.call_count == 1
assert factory.cleanup_count == 0

# Second call (state should be maintained)
gen2 = await provide()
assert isinstance(gen2, AsyncGenerator)

session2 = await gen2.__anext__()
assert session2 == 2
assert factory.call_count == 2
assert factory.cleanup_count == 0

# Cleanup first generator
with pytest.raises(StopAsyncIteration):
await gen1.__anext__()
assert factory.cleanup_count == 1

# Cleanup second generator
with pytest.raises(StopAsyncIteration):
await gen2.__anext__()
assert factory.cleanup_count == 2
Loading