Skip to content

Commit 4ddb2cb

Browse files
GWealecopybara-github
authored andcommitted
chore: Close database engines to avoid aiosqlite pytest hangs
Co-authored-by: George Weale <gweale@google.com> PiperOrigin-RevId: 852428755
1 parent 8789ad8 commit 4ddb2cb

File tree

4 files changed

+109
-222
lines changed

4 files changed

+109
-222
lines changed

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,7 @@ classifiers = [ # List of https://pypi.org/classifiers/
2626
dependencies = [
2727
# go/keep-sorted start
2828
"PyYAML>=6.0.2, <7.0.0", # For APIHubToolset.
29-
# TODO: Update aiosqlite version once https://github.com/omnilib/aiosqlite/issues/369 is fixed.
30-
"aiosqlite==0.21.0", # For SQLite database
29+
"aiosqlite>=0.21.0", # For SQLite database
3130
"anyio>=4.9.0, <5.0.0", # For MCP Session Manager
3231
"authlib>=1.5.1, <2.0.0", # For RestAPI Tool
3332
"click>=8.1.8, <9.0.0", # For CLI tools
@@ -110,6 +109,7 @@ eval = [
110109
"google-cloud-aiplatform[evaluation]>=1.100.0",
111110
"pandas>=2.2.3",
112111
"rouge-score>=0.1.2",
112+
"scipy<1.16; python_version<'3.11'",
113113
"tabulate>=0.9.0",
114114
# go/keep-sorted end
115115
]

src/google/adk/sessions/database_session_service.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,14 @@
2525
from sqlalchemy import event
2626
from sqlalchemy import select
2727
from sqlalchemy import text
28+
from sqlalchemy.engine import make_url
2829
from sqlalchemy.exc import ArgumentError
2930
from sqlalchemy.ext.asyncio import async_sessionmaker
3031
from sqlalchemy.ext.asyncio import AsyncEngine
3132
from sqlalchemy.ext.asyncio import AsyncSession as DatabaseSessionFactory
3233
from sqlalchemy.ext.asyncio import create_async_engine
3334
from sqlalchemy.inspection import inspect
35+
from sqlalchemy.pool import StaticPool
3436
from typing_extensions import override
3537
from tzlocal import get_localzone
3638

@@ -103,7 +105,15 @@ def __init__(self, db_url: str, **kwargs: Any):
103105
# 2. Create all tables based on schema
104106
# 3. Initialize all properties
105107
try:
106-
db_engine = create_async_engine(db_url, **kwargs)
108+
engine_kwargs = dict(kwargs)
109+
url = make_url(db_url)
110+
if url.get_backend_name() == "sqlite" and url.database == ":memory:":
111+
engine_kwargs.setdefault("poolclass", StaticPool)
112+
connect_args = dict(engine_kwargs.get("connect_args", {}))
113+
connect_args.setdefault("check_same_thread", False)
114+
engine_kwargs["connect_args"] = connect_args
115+
116+
db_engine = create_async_engine(db_url, **engine_kwargs)
107117
if db_engine.dialect.name == "sqlite":
108118
# Set sqlite pragma to enable foreign keys constraints
109119
event.listen(db_engine.sync_engine, "connect", _set_sqlite_pragma)
@@ -477,3 +487,15 @@ async def append_event(self, session: Session, event: Event) -> Event:
477487
# Also update the in-memory session
478488
await super().append_event(session=session, event=event)
479489
return event
490+
491+
async def close(self) -> None:
492+
"""Disposes the SQLAlchemy engine and closes pooled connections."""
493+
await self.db_engine.dispose()
494+
495+
async def __aenter__(self) -> DatabaseSessionService:
496+
"""Enters the async context manager and returns this service."""
497+
return self
498+
499+
async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
500+
"""Exits the async context manager and closes the service."""
501+
await self.close()

tests/unittests/sessions/migration/test_database_schema.py

Lines changed: 56 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -29,17 +29,20 @@ async def create_v0_db(db_path):
2929
await engine.dispose()
3030

3131

32+
# Use async context managers so DatabaseSessionService always closes.
33+
34+
3235
@pytest.mark.asyncio
3336
async def test_new_db_uses_latest_schema(tmp_path):
3437
db_path = tmp_path / 'new_db.db'
3538
db_url = f'sqlite+aiosqlite:///{db_path}'
36-
session_service = DatabaseSessionService(db_url)
37-
assert session_service._db_schema_version is None
38-
await session_service.create_session(app_name='my_app', user_id='test_user')
39-
assert (
40-
session_service._db_schema_version
41-
== _schema_check_utils.LATEST_SCHEMA_VERSION
42-
)
39+
async with DatabaseSessionService(db_url) as session_service:
40+
assert session_service._db_schema_version is None
41+
await session_service.create_session(app_name='my_app', user_id='test_user')
42+
assert (
43+
session_service._db_schema_version
44+
== _schema_check_utils.LATEST_SCHEMA_VERSION
45+
)
4346

4447
# Verify metadata table
4548
engine = create_async_engine(db_url)
@@ -71,21 +74,20 @@ async def test_existing_v0_db_uses_v0_schema(tmp_path):
7174
db_path = tmp_path / 'v0_db.db'
7275
await create_v0_db(db_path)
7376
db_url = f'sqlite+aiosqlite:///{db_path}'
74-
session_service = DatabaseSessionService(db_url)
75-
76-
assert session_service._db_schema_version is None
77-
await session_service.create_session(
78-
app_name='my_app', user_id='test_user', session_id='s1'
79-
)
80-
assert (
81-
session_service._db_schema_version
82-
== _schema_check_utils.SCHEMA_VERSION_0_PICKLE
83-
)
84-
85-
session = await session_service.get_session(
86-
app_name='my_app', user_id='test_user', session_id='s1'
87-
)
88-
assert session.id == 's1'
77+
async with DatabaseSessionService(db_url) as session_service:
78+
assert session_service._db_schema_version is None
79+
await session_service.create_session(
80+
app_name='my_app', user_id='test_user', session_id='s1'
81+
)
82+
assert (
83+
session_service._db_schema_version
84+
== _schema_check_utils.SCHEMA_VERSION_0_PICKLE
85+
)
86+
87+
session = await session_service.get_session(
88+
app_name='my_app', user_id='test_user', session_id='s1'
89+
)
90+
assert session.id == 's1'
8991

9092
# Verify schema tables
9193
engine = create_async_engine(db_url)
@@ -111,38 +113,38 @@ async def test_existing_latest_db_uses_latest_schema(tmp_path):
111113
db_url = f'sqlite+aiosqlite:///{db_path}'
112114

113115
# Create session service which creates db with latest schema
114-
session_service1 = DatabaseSessionService(db_url)
115-
await session_service1.create_session(
116-
app_name='my_app', user_id='test_user', session_id='s1'
117-
)
118-
assert (
119-
session_service1._db_schema_version
120-
== _schema_check_utils.LATEST_SCHEMA_VERSION
121-
)
122-
123-
# Create another session service on same db and check it detects latest schema
124-
session_service2 = DatabaseSessionService(db_url)
125-
await session_service2.create_session(
126-
app_name='my_app', user_id='test_user2', session_id='s2'
127-
)
128-
assert (
129-
session_service2._db_schema_version
130-
== _schema_check_utils.LATEST_SCHEMA_VERSION
131-
)
132-
s2 = await session_service2.get_session(
133-
app_name='my_app', user_id='test_user2', session_id='s2'
134-
)
135-
assert s2.id == 's2'
136-
137-
s1 = await session_service2.get_session(
138-
app_name='my_app', user_id='test_user', session_id='s1'
139-
)
140-
assert s1.id == 's1'
141-
142-
list_sessions_response = await session_service2.list_sessions(
143-
app_name='my_app'
144-
)
145-
assert len(list_sessions_response.sessions) == 2
116+
async with DatabaseSessionService(db_url) as session_service1:
117+
await session_service1.create_session(
118+
app_name='my_app', user_id='test_user', session_id='s1'
119+
)
120+
assert (
121+
session_service1._db_schema_version
122+
== _schema_check_utils.LATEST_SCHEMA_VERSION
123+
)
124+
125+
# Create another session service on same db and check it detects latest schema
126+
async with DatabaseSessionService(db_url) as session_service2:
127+
await session_service2.create_session(
128+
app_name='my_app', user_id='test_user2', session_id='s2'
129+
)
130+
assert (
131+
session_service2._db_schema_version
132+
== _schema_check_utils.LATEST_SCHEMA_VERSION
133+
)
134+
s2 = await session_service2.get_session(
135+
app_name='my_app', user_id='test_user2', session_id='s2'
136+
)
137+
assert s2.id == 's2'
138+
139+
s1 = await session_service2.get_session(
140+
app_name='my_app', user_id='test_user', session_id='s1'
141+
)
142+
assert s1.id == 's1'
143+
144+
list_sessions_response = await session_service2.list_sessions(
145+
app_name='my_app'
146+
)
147+
assert len(list_sessions_response.sessions) == 2
146148

147149
# Verify schema tables
148150
engine = create_async_engine(db_url)

0 commit comments

Comments
 (0)