Skip to content

Commit

Permalink
Showing 9 changed files with 114 additions and 26 deletions.
1 change: 1 addition & 0 deletions src/integrations/prefect-dbt/prefect_dbt/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from . import _version

from .core import PrefectDbtSettings
from .cloud import DbtCloudCredentials, DbtCloudJob # noqa
from .cli import ( # noqa
DbtCliProfile,
33 changes: 22 additions & 11 deletions src/integrations/prefect-dbt/prefect_dbt/cli/configs/base.py
Original file line number Diff line number Diff line change
@@ -4,7 +4,7 @@
from pathlib import Path
from typing import Any, Dict, Optional, Type

from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, model_validator
from typing_extensions import Self

from prefect.blocks.core import Block
@@ -40,15 +40,15 @@ class DbtConfigs(Block, abc.ABC):

def _populate_configs_json(
self,
configs_json: Dict[str, Any],
fields: Dict[str, Any],
model: BaseModel = None,
) -> Dict[str, Any]:
configs_json: dict[str, Any],
fields: dict[str, Any],
model: Optional[BaseModel] = None,
) -> dict[str, Any]:
"""
Recursively populate configs_json.
"""
# if allow_field_overrides is True keys from TargetConfigs take precedence
override_configs_json = {}
override_configs_json: dict[str, Any] = {}

for field_name, field in fields.items():
if model is not None:
@@ -93,7 +93,7 @@ def _populate_configs_json(
configs_json.update(override_configs_json)
return configs_json

def get_configs(self) -> Dict[str, Any]:
def get_configs(self) -> dict[str, Any]:
"""
Returns the dbt configs, likely used eventually for writing to profiles.yml.
@@ -120,6 +120,19 @@ class BaseTargetConfigs(DbtConfigs, abc.ABC):
),
)

@model_validator(mode="before")
@classmethod
def handle_target_configs(cls, v: Any) -> Any:
"""Handle target configs field aliasing during validation"""
if isinstance(v, dict):
if "schema_" in v:
v["schema"] = v.pop("schema_")
# Handle nested blocks
for value in v.values():
if isinstance(value, dict) and "schema_" in value:
value["schema"] = value.pop("schema_")
return v


class TargetConfigs(BaseTargetConfigs):
"""
@@ -289,7 +302,7 @@ class GlobalConfigs(DbtConfigs):
write_json: Optional[bool] = Field(
default=None,
description=(
"Determines whether dbt writes JSON artifacts to " "the target/ directory."
"Determines whether dbt writes JSON artifacts to the target/ directory."
),
)
warn_error: Optional[bool] = Field(
@@ -321,9 +334,7 @@ class GlobalConfigs(DbtConfigs):
)
use_experimental_parser: Optional[bool] = Field(
default=None,
description=(
"Opt into the latest experimental version " "of the static parser."
),
description=("Opt into the latest experimental version of the static parser."),
)
static_parser: Optional[bool] = Field(
default=None,
29 changes: 22 additions & 7 deletions src/integrations/prefect-dbt/prefect_dbt/cli/credentials.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
"""Module containing credentials for interacting with dbt CLI"""

from typing import Any, Dict, Optional, Union
from typing import Annotated, Any, Dict, Optional, Union

from pydantic import Field
from pydantic import Discriminator, Field, Tag

from prefect.blocks.core import Block
from prefect_dbt.cli.configs import GlobalConfigs, TargetConfigs
@@ -23,6 +23,18 @@
PostgresTargetConfigs = None


def target_configs_discriminator(v: Any) -> str:
"""
Discriminator function for target configs. Returns the block type slug.
"""
if isinstance(v, dict):
return v.get("block_type_slug", "dbt-cli-target-configs")
if isinstance(v, Block):
# When creating a new instance, we get a concrete Block type
return v.get_block_type_slug()
return "dbt-cli-target-configs" # Default to base type


class DbtCliProfile(Block):
"""
Profile for use across dbt CLI tasks and flows.
@@ -116,11 +128,14 @@ class DbtCliProfile(Block):
target: str = Field(
default=..., description="The default target your dbt project will use."
)
target_configs: Union[
SnowflakeTargetConfigs,
BigQueryTargetConfigs,
PostgresTargetConfigs,
TargetConfigs,
target_configs: Annotated[
Union[
Annotated[SnowflakeTargetConfigs, Tag("dbt-cli-snowflake-target-configs")],
Annotated[BigQueryTargetConfigs, Tag("dbt-cli-bigquery-target-configs")],
Annotated[PostgresTargetConfigs, Tag("dbt-cli-postgres-target-configs")],
Annotated[TargetConfigs, Tag("dbt-cli-target-configs")],
],
Discriminator(target_configs_discriminator),
] = Field(
default=...,
description=(
3 changes: 3 additions & 0 deletions src/integrations/prefect-dbt/prefect_dbt/core/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from prefect_dbt.core.settings import PrefectDbtSettings

__all__ = ["PrefectDbtSettings"]
17 changes: 17 additions & 0 deletions src/integrations/prefect-dbt/prefect_dbt/core/settings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
"""
A class for configuring or automatically discovering settings to be used with PrefectDbtRunner.
"""

from pathlib import Path

from dbt_common.events.base_types import EventLevel
from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict


class PrefectDbtSettings(BaseSettings):
model_config = SettingsConfigDict(env_prefix="DBT_")

profiles_dir: Path = Field(default=Path.home() / ".dbt")
project_dir: Path = Field(default_factory=Path.cwd)
log_level: EventLevel = Field(default=EventLevel.INFO)
11 changes: 6 additions & 5 deletions src/integrations/prefect-dbt/tests/cli/test_commands.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import datetime
import os
from pathlib import Path
from unittest.mock import MagicMock

@@ -409,12 +408,14 @@ def test_flow():


@pytest.mark.usefixtures("dbt_runner_ls_result")
def test_trigger_dbt_cli_command_find_env(profiles_dir, dbt_cli_profile_bare):
def test_trigger_dbt_cli_command_find_env(
profiles_dir, dbt_cli_profile_bare, monkeypatch
):
@flow
def test_flow():
return trigger_dbt_cli_command("ls", dbt_cli_profile=dbt_cli_profile_bare)

os.environ["DBT_PROFILES_DIR"] = str(profiles_dir)
monkeypatch.setenv("DBT_PROFILES_DIR", str(profiles_dir))
result = test_flow()
assert isinstance(result, dbtRunnerResult)

@@ -474,9 +475,9 @@ def dbt_cli_profile(self):
)

def test_find_valid_profiles_dir_default_env(
self, tmp_path, mock_open_process, mock_shell_process
self, tmp_path, mock_open_process, mock_shell_process, monkeypatch
):
os.environ["DBT_PROFILES_DIR"] = str(tmp_path)
monkeypatch.setenv("DBT_PROFILES_DIR", str(tmp_path))
(tmp_path / "profiles.yml").write_text("test")
DbtCoreOperation(commands=["dbt debug"]).run()
actual = str(mock_open_process.call_args_list[0][1]["env"]["DBT_PROFILES_DIR"])
7 changes: 5 additions & 2 deletions src/integrations/prefect-dbt/tests/cli/test_credentials.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import pytest
from prefect_dbt.cli.credentials import DbtCliProfile, GlobalConfigs, TargetConfigs
from pydantic import ValidationError
from typing_extensions import Literal


@pytest.mark.parametrize("configs_type", ["dict", "model"])
def test_dbt_cli_profile_init(configs_type):
def test_dbt_cli_profile_init(configs_type: Literal["dict", "model"]):
target_configs = dict(type="snowflake", schema="schema")
global_configs = dict(use_colors=False)
if configs_type == "model":
@@ -60,7 +61,9 @@ def test_dbt_cli_profile_get_profile():
"class_target_configs",
],
)
async def test_dbt_cli_profile_save_load_roundtrip(target_configs_request, request):
async def test_dbt_cli_profile_save_load_roundtrip(
target_configs_request: str, request: pytest.FixtureRequest
):
target_configs = request.getfixturevalue(target_configs_request)
dbt_cli_profile = DbtCliProfile(
name="my_name",
37 changes: 37 additions & 0 deletions src/integrations/prefect-dbt/tests/core/test_settings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from pathlib import Path

from dbt_common.events.base_types import EventLevel
from prefect_dbt.core.settings import PrefectDbtSettings
from pytest import MonkeyPatch


def test_default_settings():
settings = PrefectDbtSettings()
assert settings.profiles_dir == Path.home() / ".dbt"
assert settings.project_dir == Path.cwd()
assert settings.log_level == EventLevel.INFO


def test_custom_settings():
custom_profiles_dir = Path("/custom/profiles/dir")
custom_project_dir = Path("/custom/project/dir")

settings = PrefectDbtSettings(
profiles_dir=custom_profiles_dir, project_dir=custom_project_dir
)

assert settings.profiles_dir == custom_profiles_dir
assert settings.project_dir == custom_project_dir


def test_env_var_override(monkeypatch: MonkeyPatch):
env_profiles_dir = "/env/profiles/dir"
env_project_dir = "/env/project/dir"

monkeypatch.setenv("DBT_PROFILES_DIR", env_profiles_dir)
monkeypatch.setenv("DBT_PROJECT_DIR", env_project_dir)

settings = PrefectDbtSettings()

assert settings.profiles_dir == Path(env_profiles_dir)
assert settings.project_dir == Path(env_project_dir)
2 changes: 1 addition & 1 deletion src/prefect/blocks/core.py
Original file line number Diff line number Diff line change
@@ -669,7 +669,7 @@ def _generate_code_example(cls) -> str:
module_str = ".".join(qualified_name.split(".")[:-1])
origin = cls.__pydantic_generic_metadata__.get("origin") or cls
class_name = origin.__name__
block_variable_name = f'{cls.get_block_type_slug().replace("-", "_")}_block'
block_variable_name = f"{cls.get_block_type_slug().replace('-', '_')}_block"

return dedent(
f"""\

0 comments on commit d5279c5

Please sign in to comment.