Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion docs/api/checks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,9 @@ While not particularly useful when operating on a single column, DataFrame-level

@cf.Check(columns="checking_balance", input_type="Series", return_type="Series", native=False)
def series_series_check(s):
return s <= 100
return s <= 100

In addition, it is possible to omit the *return* type hint or the *return_type* parameter, in which case **checkedframe** will inspect the resulting object to try and infer the return type. However, it is required to type hint the *input* or specify the *input_type* parameter.


Built-in Checks
Expand Down
41 changes: 13 additions & 28 deletions src/checkedframe/_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def _is_expr(x: Any) -> bool:

def _is_dataframe(x: Any) -> bool:
return (
isinstance(x, nw.DataFrame)
issubclass(x, nw.DataFrame)
or _is_polars_dataframe(x)
or _is_pandas_dataframe(x)
or _is_modin_dataframe(x)
Expand Down Expand Up @@ -117,25 +117,13 @@ def _infer_input_type(
return "auto"


def _infer_return_type(
type_hints: dict[str, Any], input_type: CheckInputType
) -> CheckReturnType:
try:
# Try to get it from the type hints first
type_hint = type_hints["return"]

if issubclass(type_hint, bool):
return "bool"
elif _is_expr(type_hint):
return "Expr"
elif _is_series(type_hint):
return "Series"
except KeyError:
# If type hints don't exist, we try to infer from the input_type
pass

if input_type == "str" or input_type is None:
def _infer_return_type(typ) -> CheckReturnType:
if issubclass(typ, bool):
return "bool"
elif _is_expr(typ):
return "Expr"
elif _is_series(typ):
return "Series"

return "auto"

Expand Down Expand Up @@ -501,10 +489,12 @@ def _set_params(self) -> None:
self.input_type = _infer_input_type(type_hints, signature)

if auto_return_type:
self.return_type = _infer_return_type(
type_hints,
self.input_type,
)
try:
self.return_type = _infer_return_type(
type_hints["return"],
)
except KeyError:
self.return_type = "auto"

if self.native == "auto":
raise ValueError(
Expand All @@ -516,11 +506,6 @@ def _set_params(self) -> None:
f"Input type of `{self.name}` could not be automatically determined from context"
)

if self.return_type == "auto":
raise ValueError(
f"Return type of `{self.name}` could not be automatically determined from context"
)

if self.name is None:
self.name = None if self.func.__name__ == "<lambda>" else self.func.__name__

Expand Down
115 changes: 63 additions & 52 deletions src/checkedframe/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import narwhals.stable.v1 as nw
import narwhals.stable.v1.typing as nwt

from ._checks import Check
from ._checks import Check, _infer_return_type
from ._config import ConfigList
from ._dtypes import CastError, CfUnion, TypedColumn, _nw_type_to_cf_type
from ._utils import get_class_members
Expand Down Expand Up @@ -75,67 +75,78 @@ def _run_check(
{"check_name": check_name, "check_description": check.description}
)

if check.return_type == "Expr":
if check.input_type == "str":
expr = check.func(series_name)
else:
expr = check.func()
if check.input_type == "str":
res = check.func(series_name)
elif check.input_type is None:
res = check.func()
elif check.input_type == "Series":
if series_name is None:
raise ValueError(
"Series cannot be automatically determined in this context"
)

input_ = nw_df[series_name]

if check.native:
input_ = input_.to_native()

res = check.func(input_)
elif check.input_type == "Frame":
# mypy complains here that the input type is Series, not DataFrame, but it
# is only a Series if the above branch is hit, which means this branch is
# not
input_ = nw_df # type: ignore[assignment]

if check.native:
input_ = input_.to_native()

assert isinstance(check.native, bool)
res = check.func(input_)
else:
# We should never hit this branch since the input type always needs to be
# specified statically
raise ValueError("Invalid input type")

check_return_type = check.return_type

if check_return_type == "auto":
# We need to get the type (the class) instead of the instance
check_return_type = _infer_return_type(type(res))

if check_return_type == "Expr":
return _ResultWrapper(
expr,
res,
msg=err_msg,
identifier=new_check_name,
column=column_name,
operation=check_name,
native=check.native,
# We have already transformed this to a boolean
native=check.native, # type: ignore[arg-type]
is_expr=True,
)
elif check_return_type == "Series":
res = nw.from_native(res, series_only=True)
return _ResultWrapper(
res,
msg=err_msg,
identifier=new_check_name,
column=column_name,
operation=check_name,
native=False,
is_expr=False,
)
elif check_return_type == "bool":
res = nw.lit(res)
return _ResultWrapper(
res,
msg=err_msg,
identifier=new_check_name,
column=column_name,
operation=check_name,
native=False,
is_expr=True,
)
else:
if check.input_type == "Series":
if series_name is None:
raise ValueError(
"Series cannot be automatically determined in this context"
)

input_ = nw_df[series_name]
elif check.input_type == "Frame":
# mypy complains here that the input type is Series, not DataFrame, but it
# is only a Series if the above branch is hit, which means this branch is
# not

input_ = nw_df # type: ignore[assignment]
else:
raise ValueError("Invalid input type")

if check.native:
input_ = input_.to_native()

res = check.func(input_)

if check.return_type == "Series":
res = nw.from_native(res, series_only=True)
return _ResultWrapper(
res,
msg=err_msg,
identifier=new_check_name,
column=column_name,
operation=check_name,
native=False,
is_expr=False,
)
elif check.return_type == "bool":
res = nw.lit(res)
return _ResultWrapper(
res,
msg=err_msg,
identifier=new_check_name,
column=column_name,
operation=check_name,
native=False,
is_expr=True,
)
raise ValueError(f"Invalid return_type {check_return_type}")


@dataclasses.dataclass
Expand Down
24 changes: 24 additions & 0 deletions tests/test_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,30 @@ def a_check3(s: pd.Series) -> bool:
assert col_checks["a_check3"].return_type == "bool"


def test_raises_if_type_inference_fails():
with pytest.raises(ValueError):

class S(cf.Schema):
a = cf.String()

@cf.Check(columns="a")
def invalid_input(s) -> pl.Series:
return s.str.len_chars() < 3


def test_type_inference_from_object():
class S(cf.Schema):
a = cf.String()

@cf.Check(columns="a")
def a_check(s: pl.Series):
return s.str.len_chars().lt(3)

df = pl.DataFrame({"a": ["a", "b", "c"]})

S.validate(df)


@pytest.mark.parametrize("engine", ENGINES)
def test_is_between(engine):
df = engine({"a": [1, 2, 3]})
Expand Down