
Commit bd25c36

adds support for sqlite and mysql in sqlalchemy
1 parent 517afaa commit bd25c36

File tree: 6 files changed, +164 −30 lines changed

.github/workflows/test_destinations_remote.yml

Lines changed: 1 addition & 1 deletion

@@ -112,7 +112,7 @@ jobs:
          - name: mssql
            destinations: "[\"mssql\"]"
            filesystem_drivers: "[\"memory\"]"
-           extras: "--extra mssql --extra s3 --extra gs --extra az --extra parquet --extra adbc"
+           extras: "--extra mssql --extra s3 --extra gs --extra az --extra parquet --group adbc"
            pre_install_commands: "sudo ACCEPT_EULA=Y apt-get install --yes msodbcsql18"
            post_install_commands: "uv run dbc install mssql"
            always_run_all_tests: true

dlt/destinations/_adbc_jobs.py

Lines changed: 27 additions & 4 deletions

@@ -8,6 +8,7 @@
 from dlt.common.destination.capabilities import LoaderFileFormatSelector
 from dlt.common.schema.typing import TTableSchema
 from dlt.common.typing import TLoaderFileFormat
+from dlt.common.utils import without_none
 from dlt.destinations.job_client_impl import SqlJobClientBase

 if TYPE_CHECKING:
@@ -20,11 +21,33 @@ class AdbcParquetCopyJob(RunnableLoadJob, ABC):
     def __init__(self, file_path: str) -> None:
         super().__init__(file_path)
         self._job_client: SqlJobClientBase = None
+        # override default schema handling
+        self._connect_catalog_name: str = None
+        self._connect_schema_name: str = None

     @abstractmethod
     def _connect(self) -> Connection:
         pass

+    def _set_catalog_and_schema(self) -> Tuple[str, str]:
+        catalog_name = self._connect_catalog_name
+        if catalog_name is None:
+            catalog_name = self._job_client.sql_client.catalog_name(quote=False)
+        elif catalog_name == "":
+            # empty string disables catalog
+            catalog_name = None
+
+        schema_name = self._connect_schema_name
+        if schema_name is None:
+            schema_name = self._job_client.sql_client.escape_column_name(
+                self._job_client.sql_client.dataset_name, quote=False, casefold=True
+            )
+        elif schema_name == "":
+            # empty string disables schema
+            schema_name = None
+
+        return catalog_name, schema_name
+
     def run(self) -> None:
         from dlt.common.libs.pyarrow import pq_stream_with_new_columns
         from dlt.common.libs.pyarrow import pyarrow
@@ -36,15 +59,15 @@ def _iter_batches(file_path: str) -> Iterator[pyarrow.RecordBatch]:
         with self._connect() as conn, conn.cursor() as cur:
             import time

+            catalog_name, schema_name = self._set_catalog_and_schema()
+            kwargs = dict(catalog_name=catalog_name, db_schema_name=schema_name)
+
             t_ = time.time()
             rows = cur.adbc_ingest(
                 self.load_table_name,
                 _iter_batches(self._file_path),
                 mode="append",
-                catalog_name=self._job_client.sql_client.catalog_name(quote=False),
-                db_schema_name=self._job_client.sql_client.fully_qualified_dataset_name(
-                    quote=False
-                ),
+                **without_none(kwargs),  # type: ignore[arg-type]
             )
             conn.commit()
             logger.warning(
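A note on the new kwargs handling: without_none (from dlt.common.utils) is assumed here to return a copy of the dict with None-valued entries dropped, so a subclass that sets the catalog/schema sentinels to an empty string ends up calling adbc_ingest with neither argument. A minimal sketch under that assumption:

from dlt.common.utils import without_none  # assumed: drops None-valued entries

# default case: both names are resolved from the sql_client and passed through
kwargs = dict(catalog_name="master", db_schema_name="chess_data")
print(without_none(kwargs))  # {'catalog_name': 'master', 'db_schema_name': 'chess_data'}

# sqlite/mysql case: "" sentinels are mapped to None by _set_catalog_and_schema,
# so adbc_ingest is called without catalog_name and db_schema_name at all
kwargs = dict(catalog_name=None, db_schema_name=None)
print(without_none(kwargs))  # {}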

dlt/destinations/impl/mssql/mssql.py

Lines changed: 1 addition & 1 deletion

@@ -94,7 +94,7 @@ def _connect(self) -> "Connection":
         self._config = self._job_client.config  # type: ignore[assignment]
         conn_dsn = self.odbc_to_go_mssql_dsn(self._config.credentials.get_odbc_dsn_dict())
         conn_str = ";".join([f"{k}={v}" for k, v in conn_dsn.items()])
-        logger.warning(f"ADBC connecting to {conn_str}")
+        logger.info(f"ADBC connect to {conn_str}")
         return dbapi.connect(driver="mssql", db_kwargs={"uri": conn_str})

     @staticmethod

dlt/destinations/impl/sqlalchemy/load_jobs.py

Lines changed: 89 additions & 1 deletion

@@ -1,17 +1,21 @@
+from __future__ import annotations
+
 from typing import IO, Any, Dict, Iterator, List, Sequence, TYPE_CHECKING, Optional
 import math

 import sqlalchemy as sa

+from dlt.common import logger
 from dlt.common.destination.client import (
     RunnableLoadJob,
     HasFollowupJobs,
     PreparedTableSchema,
 )
 from dlt.common.storages import FileStorage
 from dlt.common.json import json, PY_DATETIME_DECODERS
-from dlt.destinations.sql_jobs import SqlFollowupJob

+from dlt.destinations._adbc_jobs import AdbcParquetCopyJob
+from dlt.destinations.sql_jobs import SqlFollowupJob
 from dlt.destinations.impl.sqlalchemy.db_api_client import SqlalchemyClient
 from dlt.destinations.impl.sqlalchemy.merge_job import SqlalchemyMergeFollowupJob

@@ -74,6 +78,90 @@ def run(self) -> None:
                 _sql_client.execute_sql(table.insert(), chunk)


+class SqlalchemyParquetADBCJob(AdbcParquetCopyJob):
+    """ADBC Parquet copy job for SQLAlchemy (sqlite, mysql) with query param handling."""
+
+    def __init__(self, file_path: str, table: sa.Table) -> None:
+        super().__init__(file_path)
+        self._job_client: "SqlalchemyJobClient" = None
+        self.table = table
+
+    if TYPE_CHECKING:
+        from adbc_driver_manager.dbapi import Connection
+
+    def _connect(self) -> Connection:
+        from adbc_driver_manager import dbapi
+
+        engine = self._job_client.config.credentials.engine
+        dialect = engine.dialect.name.lower()
+        url = engine.url
+
+        query = dict(url.query or {})
+
+        if dialect == "sqlite":
+            # disable schema and catalog when ingesting
+            self._connect_schema_name = ""
+            self._connect_catalog_name = ""
+
+            # attach directly to the dataset sqlite file as "main"
+            if self._job_client.sql_client.dataset_name == "main":
+                db_path = url.database
+            else:
+                db_path = self._job_client.sql_client._sqlite_dataset_filename(
+                    self._job_client.sql_client.dataset_name
+                )
+            conn_str = f"file:{db_path}"
+
+            if query:
+                qs = "&".join(f"{k}={v}" for k, v in query.items())
+                conn_str = f"{conn_str}?{qs}"
+
+            logger.info(f"ADBC connect to {conn_str}")
+            return dbapi.connect(driver="sqlite", db_kwargs={"uri": conn_str})
+
+        elif dialect == "mysql":
+            # disable schema and catalog when ingesting
+            self._connect_schema_name = ""
+            self._connect_catalog_name = ""
+
+            # mysql: convert SSL params into go-mysql ADBC parameters
+            mapped = {}
+            for k, v in query.items():
+                lk = k.lower()
+                if lk == "ssl_ca":
+                    mapped["tls-ca"] = v
+                elif lk == "ssl_cert":
+                    mapped["tls-cert"] = v
+                elif lk == "ssl_key":
+                    mapped["tls-key"] = v
+                elif lk == "ssl_mode":
+                    mapped["tls"] = v
+                else:
+                    mapped[k] = v
+
+            username = url.username or ""
+            password = url.password or ""
+            auth = f"{username}:{password}@" if username or password else ""
+
+            host = url.host or "localhost"
+            port = url.port or 3306
+            # dataset name == schema name == database name: each database is a schema in mysql
+            database = self._job_client.sql_client.dataset_name  # url.database or ""
+
+            base = f"{auth}tcp({host}:{port})/{database}"
+            if mapped:
+                qs = "&".join(f"{k}={v}" for k, v in mapped.items())
+                conn_str = f"{base}?{qs}"
+            else:
+                conn_str = base
+
+            logger.info(f"ADBC connect to {conn_str}")
+            return dbapi.connect(driver="mysql", db_kwargs={"uri": conn_str})
+
+        else:
+            raise NotImplementedError(f"ADBC not supported for sqlalchemy dialect {dialect}")
+
+
 class SqlalchemyParquetInsertJob(SqlalchemyJsonLInsertJob):
     def _iter_data_item_chunks(self) -> Iterator[Sequence[Dict[str, Any]]]:
         from dlt.common.libs.pyarrow import ParquetFile
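For orientation, here is a minimal standalone sketch of the mysql branch above: turning a SQLAlchemy URL into the go-mysql style DSN that the ADBC driver expects. It only relies on the sqlalchemy.engine.URL API; the helper name is illustrative and the parameter mapping mirrors the diff rather than defining it. The sqlite branch is simpler: the DSN is just file:<path-to-dataset-file>, optionally followed by the original query string.

import sqlalchemy as sa

def go_mysql_dsn(url: sa.engine.URL, database: str) -> str:
    # translate SQLAlchemy-style SSL query params into go-mysql ADBC parameters
    ssl_map = {"ssl_ca": "tls-ca", "ssl_cert": "tls-cert", "ssl_key": "tls-key", "ssl_mode": "tls"}
    params = {ssl_map.get(k.lower(), k): v for k, v in dict(url.query or {}).items()}

    auth = f"{url.username or ''}:{url.password or ''}@" if (url.username or url.password) else ""
    dsn = f"{auth}tcp({url.host or 'localhost'}:{url.port or 3306})/{database}"
    if params:
        dsn += "?" + "&".join(f"{k}={v}" for k, v in params.items())
    return dsn

url = sa.engine.make_url("mysql://loader:secret@db.example.com/dlt_data?ssl_mode=required")
print(go_mysql_dsn(url, "dlt_data"))  # loader:secret@tcp(db.example.com:3306)/dlt_data?tls=required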

dlt/destinations/impl/sqlalchemy/sqlalchemy_job_client.py

Lines changed: 9 additions & 2 deletions

@@ -14,7 +14,6 @@
     PreparedTableSchema,
     FollowupJobRequest,
 )
-from dlt.destinations.job_client_impl import SqlJobClientWithStagingDataset, SqlLoadJob
 from dlt.common.destination.capabilities import DestinationCapabilitiesContext
 from dlt.common.schema import Schema, TTableSchema, TColumnSchema, TSchemaTables
 from dlt.common.schema.typing import (
@@ -30,11 +29,15 @@
     get_columns_names_with_prop,
 )
 from dlt.common.storages.load_storage import ParsedLoadJobFileName
+
+from dlt.destinations.job_client_impl import SqlJobClientWithStagingDataset
+from dlt.destinations._adbc_jobs import has_driver as adbc_has_driver
 from dlt.destinations.exceptions import DatabaseUndefinedRelation
 from dlt.destinations.impl.sqlalchemy.db_api_client import SqlalchemyClient
 from dlt.destinations.impl.sqlalchemy.configuration import SqlalchemyClientConfiguration
 from dlt.destinations.impl.sqlalchemy.load_jobs import (
     SqlalchemyJsonLInsertJob,
+    SqlalchemyParquetADBCJob,
     SqlalchemyParquetInsertJob,
     SqlalchemyReplaceJob,
     SqlalchemyMergeFollowupJob,
@@ -138,7 +141,11 @@ def create_load_job(
             return SqlalchemyJsonLInsertJob(file_path, table_obj)
         elif parsed_file.file_format == "parquet":
             table_obj = self._to_table_object(table)
-            return SqlalchemyParquetInsertJob(file_path, table_obj)
+            # use ADBC only if a driver for the given dialect is installed
+            if adbc_has_driver(self.config.credentials.engine.dialect.name):
+                return SqlalchemyParquetADBCJob(file_path, table_obj)
+            else:
+                return SqlalchemyParquetInsertJob(file_path, table_obj)
         return None

     def complete_load(self, load_id: str) -> None:
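The parquet branch now depends on has_driver from dlt.destinations._adbc_jobs, which the tests below call as has_driver(name)[0]. As a rough sketch of how such a probe can work with adbc_driver_manager (the real implementation may differ), opening and closing an AdbcDatabase for the dialect's driver name reveals whether the driver shared library is installed:

from typing import Tuple

def probe_adbc_driver(driver_name: str) -> Tuple[bool, str]:
    """Return (available, message) for an ADBC driver name such as 'sqlite' or 'mysql'."""
    try:
        import adbc_driver_manager as dm

        db = dm.AdbcDatabase(driver=driver_name)
        db.close()
        return True, ""
    except Exception as exc:  # driver shared library missing or manager not installed
        return False, str(exc)

# when this returns False, create_load_job falls back to SqlalchemyParquetInsertJob
available, reason = probe_adbc_driver("sqlite")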

Lines changed: 37 additions & 21 deletions

@@ -1,6 +1,7 @@
 import pytest

 import dlt
+from dlt.common import Decimal

 from tests.cases import table_update_and_row
 from tests.load.pipeline.utils import get_load_package_jobs
@@ -10,46 +11,61 @@
 )


-# def test_adbc_detection() -> None:
-#     from adbc_driver_manager import dbapi, ProgrammingError
-#     import adbc_driver_manager as dm
+@pytest.mark.parametrize(
+    "destination_config",
+    destinations_configs(default_sql_configs=True, subset=["postgres", "mssql", "sqlalchemy"]),
+    ids=lambda x: x.name,
+)
+def test_adbc_detection(destination_config: DestinationTestConfiguration) -> None:
+    from dlt.destinations._adbc_jobs import has_driver
+
+    driver = destination_config.destination_name or destination_config.destination_type
+    if driver == "postgres":
+        driver = "postgresql"
+    elif driver == "sqlalchemy_sqlite":
+        driver = "sqlite"
+    elif driver == "sqlalchemy_mysql":
+        driver = "mysql"

-#     try:
-#         db = dm.AdbcDatabase(driver="mssqll")
-#         db.close()
-#         # try:
-#         #     dbapi.connect(driver="postgresql", db_kwargs={"uri": "server"})
-#     except ProgrammingError as pr_ex:
-#         print(str(pr_ex))
-#         print(pr_ex.sqlstate)
+    assert has_driver(driver)[0] is True


 @pytest.mark.parametrize(
     "destination_config",
-    destinations_configs(default_sql_configs=True, subset=["postgres", "mssql"]),
+    destinations_configs(default_sql_configs=True, subset=["postgres", "mssql", "sqlalchemy"]),
     ids=lambda x: x.name,
 )
 def test_adbc_parquet_loading(destination_config: DestinationTestConfiguration) -> None:
-    column_schemas, data_types = table_update_and_row()
+    # if destination_config.destination_name == "sqlalchemy_sqlite":
+    #     pytest.skip("skip generic ADBC test for sqlite because just a few data types are supported")
+    column_schemas, data_ = table_update_and_row()

     pipeline = destination_config.setup_pipeline("pipeline_adbc", dev_mode=True)

-    # postgres
-    del column_schemas["col6_precision"]  # adbc cannot process decimal(6,2)
-    # mssql
-    del column_schemas["col7_precision"]  # adbc cannot process fixed binary
+    if destination_config.destination_type in ("postgres", "mssql"):
+        del column_schemas["col11_precision"]  # TIME(3) not supported
+        if destination_config.destination_type == "postgres":
+            del column_schemas["col6_precision"]  # adbc cannot process decimal(6,2)
+        else:
+            del column_schemas["col7_precision"]  # adbc cannot process fixed binary

-    # both
-    del column_schemas["col11_precision"]  # TIME(3) not supported
+    if destination_config.destination_name == "sqlalchemy_sqlite":
+        for k, v in column_schemas.items():
+            # decimal, wei and time not supported, load them as text
+            if v["data_type"] in ("decimal", "wei", "time"):
+                data_[k] = str(data_[k])
+                column_schemas[k]["data_type"] = "text"

     @dlt.resource(file_format="parquet", columns=column_schemas, max_table_nesting=0)
     def complex_resource():
-        yield data_types
+        yield data_

     info = pipeline.run(complex_resource())
     jobs = get_load_package_jobs(
         info.load_packages[0], "completed_jobs", "complex_resource", ".parquet"
     )
     # there must be a parquet job or adbc is not installed so we fall back to other job type
     assert len(jobs) == 1
-    print(pipeline.dataset().table("complex_resource").fetchall())
+    # make sure we can read data back. TODO: verify data types
+    rows = pipeline.dataset().table("complex_resource").fetchall()
+    assert len(rows) == 1
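Taken together, the changes mean a sqlalchemy pipeline pointed at sqlite or mysql picks the ADBC copy path automatically whenever the matching driver is installed, and otherwise falls back to SqlalchemyParquetInsertJob. A rough usage sketch (connection string and resource are illustrative, not part of this commit):

import dlt

@dlt.resource(file_format="parquet")  # parquet files trigger the ADBC job when a driver is present
def events():
    yield [{"id": 1, "value": 2.5}, {"id": 2, "value": 3.5}]

pipeline = dlt.pipeline(
    pipeline_name="adbc_demo",
    destination=dlt.destinations.sqlalchemy("sqlite:///adbc_demo.db"),
    dataset_name="main",  # "main" writes straight into the sqlite file, per the load_jobs change above
)
print(pipeline.run(events()))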
