diff --git a/CHANGELOG.md b/CHANGELOG.md index 86cd245..1e6b8e3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,6 @@ +- Add dbt project for testing +- Add `--dbt-project-dir` pytest flag that points to the dbt project directory + ## [0.1.0.dev2] - 2022-01-28 - Implement use as pytest plugin diff --git a/setup.cfg b/setup.cfg index b66f150..d87fa08 100644 --- a/setup.cfg +++ b/setup.cfg @@ -37,7 +37,7 @@ test = [options.entry_points] pytest11 = - dbt=pytest_dbt_core.fixtures + dbt=pytest_dbt_core.plugin [flake8] ignore = E226,E302,E41,W504,W503 @@ -60,6 +60,7 @@ addopts = --cov=src --doctest-glob="README.md" --doctest-modules --ignore=scripts/ + --dbt-project-dir="./tests/dbt_project" spark_options = spark.app.name: dbt-core spark.executor.instances: 1 @@ -78,6 +79,7 @@ setenv = PIP_DISABLE_PIP_VERSION_CHECK = 1 COVERAGE_FILE = {env:COVERAGE_FILE:{toxworkdir}/.coverage.{envname}} {py27,pypy}: PYTHONWARNINGS=ignore:DEPRECATION::pip._internal.cli.base_command + DBT_PROFILES_DIR = {env:DBT_PROFILES_DIR:{toxinidir}/tests/dbt_project} passenv = PYTEST_* PIP_CACHE_DIR diff --git a/src/pytest_dbt_core/fixtures.py b/src/pytest_dbt_core/fixtures.py index adddc47..42a6b91 100644 --- a/src/pytest_dbt_core/fixtures.py +++ b/src/pytest_dbt_core/fixtures.py @@ -15,8 +15,6 @@ from dbt.parser.manifest import ManifestLoader from dbt.tracking import User -from .session import _SparkConnectionManager - from dbt.adapters.factory import ( # isort:skip AdapterContainer, get_adapter, @@ -37,21 +35,26 @@ class Args: arguments here. """ - project_dir: str = os.getcwd() + project_dir: str @pytest.fixture -def config() -> RuntimeConfig: +def config(request: SubRequest) -> RuntimeConfig: """ Get the (runtime) config. + Parameters + ---------- + request : SubRequest + The pytest request. + Returns ------- RuntimeConfig The runtime config. 
""" - # requires a profile in your project wich also exists in your profiles file - config = RuntimeConfig.from_args(Args()) + project_dir = request.config.getoption("--dbt-project-dir") + config = RuntimeConfig.from_args(Args(project_dir=project_dir)) return config @@ -72,12 +75,7 @@ def adapter(config: RuntimeConfig) -> AdapterContainer: """ register_adapter(config) adapter = get_adapter(config) - - connection_manager = _SparkConnectionManager(adapter.config) - adapter.connections = connection_manager - adapter.acquire_connection() - return adapter diff --git a/src/pytest_dbt_core/plugin.py b/src/pytest_dbt_core/plugin.py new file mode 100644 index 0000000..2732e96 --- /dev/null +++ b/src/pytest_dbt_core/plugin.py @@ -0,0 +1,31 @@ +"""The entrypoint for the plugin.""" + +import os + +from _pytest.config.argparsing import Parser + +from .fixtures import adapter, config, macro_generator, manifest + +__all__ = ( + "adapter", + "config", + "macro_generator", + "manifest", +) + + +def pytest_addoption(parser: Parser) -> None: + """ + Add pytest option. + + Parameters + ---------- + parser : Parser + The parser. 
+ """ + parser.addoption( + "--dbt-project-dir", + help="The dbt project directory.", + type="string", + default=os.getcwd(), + ) diff --git a/tests/dbt_project/dbt_project.yml b/tests/dbt_project/dbt_project.yml new file mode 100644 index 0000000..35f853a --- /dev/null +++ b/tests/dbt_project/dbt_project.yml @@ -0,0 +1,6 @@ +name: dbt_project +profile: dbt_project +version: '0.3.0' +config-version: 2 +require-dbt-version: [">=1.0.0", "<2.0.0"] +macro-paths: ["macros"] diff --git a/tests/dbt_project/macros/fetch_single_statement.sql b/tests/dbt_project/macros/fetch_single_statement.sql new file mode 100644 index 0000000..d3506b7 --- /dev/null +++ b/tests/dbt_project/macros/fetch_single_statement.sql @@ -0,0 +1,13 @@ +{% macro fetch_single_statement(statement, default_value="") %} + + {% set results = run_query(statement) %} + + {% if execute %} + {% set value = results.columns[0].values()[0] %} + {% else %} + {% set value = default_value %} + {% endif %} + + {{ return( value ) }} + +{% endmacro %} diff --git a/tests/dbt_project/macros/prices.sql b/tests/dbt_project/macros/prices.sql new file mode 100644 index 0000000..8b50334 --- /dev/null +++ b/tests/dbt_project/macros/prices.sql @@ -0,0 +1,3 @@ +{% macro to_cents(column_name) %} + {{ column_name }} * 100 +{% endmacro %} diff --git a/tests/dbt_project/profiles.yml b/tests/dbt_project/profiles.yml new file mode 100644 index 0000000..0574136 --- /dev/null +++ b/tests/dbt_project/profiles.yml @@ -0,0 +1,8 @@ +dbt_project: + target: test + outputs: + test: + type: spark + method: session + schema: test + host: NA # not used, but required by `dbt-core` diff --git a/tests/dbt_project/tests/test_fetch_single_statement.py b/tests/dbt_project/tests/test_fetch_single_statement.py new file mode 100644 index 0000000..ddd3380 --- /dev/null +++ b/tests/dbt_project/tests/test_fetch_single_statement.py @@ -0,0 +1,15 @@ +import pytest +from dbt.clients.jinja import MacroGenerator +from pyspark.sql import SparkSession + + 
+@pytest.mark.parametrize( + "macro_generator", + ["macro.dbt_project.fetch_single_statement"], + indirect=True, +) +def test_fetch_single_statement( + spark_session: SparkSession, macro_generator: MacroGenerator +) -> None: + out = macro_generator("SELECT 1") + assert out == 1 diff --git a/tests/dbt_project/tests/test_prices.py b/tests/dbt_project/tests/test_prices.py new file mode 100644 index 0000000..811e56a --- /dev/null +++ b/tests/dbt_project/tests/test_prices.py @@ -0,0 +1,20 @@ +import pytest +from dbt.clients.jinja import MacroGenerator +from pyspark.sql import SparkSession + + +@pytest.mark.parametrize( + "macro_generator", + ["macro.dbt_project.to_cents"], + indirect=True, +) +def test_to_cents( + spark_session: SparkSession, macro_generator: MacroGenerator +) -> None: + expected = spark_session.createDataFrame([{"cents": 1000}]) + to_cents = macro_generator("price") + out = spark_session.sql( + "with data AS (SELECT 10 AS price) " + f"SELECT cast({to_cents} AS bigint) AS cents FROM data" + ) + assert out.collect() == expected.collect()