Skip to content

Commit b8917bc

Browse files
GWealecopybara-github
authored andcommitted
fix: Handle SQLite URLs in SqliteSessionService
The SqliteSessionService now accepts database paths in the form of SQLite URLs (e.g., "sqlite:///./sessions.db", "sqlite+aiosqlite:////absolute.db") Close #4077 Co-authored-by: George Weale <[email protected]> PiperOrigin-RevId: 853922433
1 parent 3c51ee7 commit b8917bc

File tree

2 files changed

+92
-7
lines changed

2 files changed

+92
-7
lines changed

src/google/adk/sessions/sqlite_session_service.py

Lines changed: 51 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
import time
2323
from typing import Any
2424
from typing import Optional
25+
from urllib.parse import unquote
26+
from urllib.parse import urlparse
2527
import uuid
2628

2729
import aiosqlite
@@ -91,6 +93,42 @@
9193
])
9294

9395

96+
def _parse_db_path(db_path: str) -> tuple[str, str, bool]:
97+
"""Normalizes a SQLite db path from a URL or filesystem path.
98+
99+
Returns:
100+
A tuple of:
101+
- filesystem path (for `os.path.exists` and user-facing messages)
102+
- value to pass to sqlite/aiosqlite connect
103+
- whether to pass `uri=True` to sqlite/aiosqlite connect
104+
105+
Notes:
106+
When a SQLAlchemy-style SQLite URL is provided, this follows SQLAlchemy's
107+
conventions:
108+
- `sqlite:///relative.db` is a path relative to the current working dir.
109+
- `sqlite:////absolute.db` is an absolute filesystem path.
110+
"""
111+
if not db_path.startswith(("sqlite:", "sqlite+aiosqlite:")):
112+
return db_path, db_path, False
113+
114+
parsed = urlparse(db_path)
115+
raw_path = unquote(parsed.path)
116+
if not raw_path:
117+
return db_path, db_path, False
118+
119+
normalized_path = raw_path
120+
if normalized_path.startswith("//"):
121+
normalized_path = normalized_path[1:]
122+
elif normalized_path.startswith("/"):
123+
normalized_path = normalized_path[1:]
124+
125+
if parsed.query:
126+
# sqlite3 only treats the filename as a URI when it starts with `file:`.
127+
return normalized_path, f"file:{normalized_path}?{parsed.query}", True
128+
129+
return normalized_path, normalized_path, False
130+
131+
94132
class SqliteSessionService(BaseSessionService):
95133
"""A session service that uses an SQLite database for storage via aiosqlite.
96134
@@ -100,17 +138,19 @@ class SqliteSessionService(BaseSessionService):
100138

101139
def __init__(self, db_path: str):
102140
"""Initializes the SQLite session service with a database path."""
103-
self._db_path = db_path
141+
self._db_path, self._db_connect_path, self._db_connect_uri = _parse_db_path(
142+
db_path
143+
)
104144

105145
if self._is_migration_needed():
106146
raise RuntimeError(
107-
f"Database {db_path} seems to use an old schema."
147+
f"Database {self._db_path} seems to use an old schema."
108148
" Please run the migration command to"
109149
" migrate it to the new schema. Example: `python -m"
110150
" google.adk.sessions.migration.migrate_from_sqlalchemy_sqlite"
111-
f" --source_db_path {db_path} --dest_db_path"
112-
f" {db_path}.new` then backup {db_path} and rename"
113-
f" {db_path}.new to {db_path}."
151+
f" --source_db_path {self._db_path} --dest_db_path"
152+
f" {self._db_path}.new` then backup {self._db_path} and rename"
153+
f" {self._db_path}.new to {self._db_path}."
114154
)
115155

116156
@override
@@ -415,7 +455,9 @@ async def append_event(self, session: Session, event: Event) -> Event:
415455
@asynccontextmanager
416456
async def _get_db_connection(self):
417457
"""Connects to the db and performs initial setup."""
418-
async with aiosqlite.connect(self._db_path) as db:
458+
async with aiosqlite.connect(
459+
self._db_connect_path, uri=self._db_connect_uri
460+
) as db:
419461
db.row_factory = aiosqlite.Row
420462
await db.execute(PRAGMA_FOREIGN_KEYS)
421463
await db.executescript(CREATE_SCHEMA_SQL)
@@ -514,7 +556,9 @@ def _is_migration_needed(self) -> bool:
514556
if not os.path.exists(self._db_path):
515557
return False
516558
try:
517-
with sqlite3.connect(self._db_path) as conn:
559+
with sqlite3.connect(
560+
self._db_connect_path, uri=self._db_connect_uri
561+
) as conn:
518562
cursor = conn.cursor()
519563
# Check if events table exists
520564
cursor.execute(

tests/unittests/sessions/test_session_service.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from datetime import datetime
1616
from datetime import timezone
1717
import enum
18+
import sqlite3
1819

1920
from google.adk.errors.already_exists_error import AlreadyExistsError
2021
from google.adk.events.event import Event
@@ -60,6 +61,46 @@ async def session_service(request, tmp_path):
6061
await service.close()
6162

6263

64+
@pytest.mark.asyncio
65+
async def test_sqlite_session_service_accepts_sqlite_urls(
66+
tmp_path, monkeypatch
67+
):
68+
monkeypatch.chdir(tmp_path)
69+
70+
service = SqliteSessionService('sqlite+aiosqlite:///./sessions.db')
71+
await service.create_session(app_name='app', user_id='user')
72+
assert (tmp_path / 'sessions.db').exists()
73+
74+
service = SqliteSessionService('sqlite:///./sessions2.db')
75+
await service.create_session(app_name='app', user_id='user')
76+
assert (tmp_path / 'sessions2.db').exists()
77+
78+
79+
@pytest.mark.asyncio
80+
async def test_sqlite_session_service_preserves_uri_query_parameters(
81+
tmp_path, monkeypatch
82+
):
83+
monkeypatch.chdir(tmp_path)
84+
db_path = tmp_path / 'readonly.db'
85+
with sqlite3.connect(db_path) as conn:
86+
conn.execute('CREATE TABLE IF NOT EXISTS t (id INTEGER)')
87+
conn.commit()
88+
89+
service = SqliteSessionService(f'sqlite+aiosqlite:///{db_path}?mode=ro')
90+
# `mode=ro` opens the DB read-only; schema creation should fail.
91+
with pytest.raises(sqlite3.OperationalError, match=r'readonly'):
92+
await service.create_session(app_name='app', user_id='user')
93+
94+
95+
@pytest.mark.asyncio
96+
async def test_sqlite_session_service_accepts_absolute_sqlite_urls(tmp_path):
97+
abs_db_path = tmp_path / 'absolute.db'
98+
abs_url = 'sqlite+aiosqlite:////' + str(abs_db_path).lstrip('/')
99+
service = SqliteSessionService(abs_url)
100+
await service.create_session(app_name='app', user_id='user')
101+
assert abs_db_path.exists()
102+
103+
63104
@pytest.mark.asyncio
64105
async def test_get_empty_session(session_service):
65106
assert not await session_service.get_session(

0 commit comments

Comments
 (0)