From 7616b5be7cbcab36b907e332a35499ccdacff254 Mon Sep 17 00:00:00 2001 From: Cangyuan Li Date: Mon, 29 Sep 2025 21:43:22 -0700 Subject: [PATCH] feat: allow users to not specify the return type --- docs/api/checks.rst | 4 +- src/checkedframe/_checks.py | 41 ++++--------- src/checkedframe/_core.py | 115 ++++++++++++++++++++---------------- tests/test_checks.py | 24 ++++++++ 4 files changed, 103 insertions(+), 81 deletions(-) diff --git a/docs/api/checks.rst b/docs/api/checks.rst index fb08ed3..3976fa2 100644 --- a/docs/api/checks.rst +++ b/docs/api/checks.rst @@ -131,7 +131,9 @@ While not particularly useful when operating on a single column, DataFrame-level @cf.Check(columns="checking_balance", input_type="Series", return_type="Series", native=False) def series_series_check(s): - return s <= 100 + return s <= 100 + + In addition, it is possible to omit the *return* type hint or the *return_type* parameter, in which case **checkedframe** will inspect the resulting object to try and infer the return type. However, it is required to type hint the *input* or specify the *input_type* parameter. Built-in Checks diff --git a/src/checkedframe/_checks.py b/src/checkedframe/_checks.py index ee2f981..bb46014 100644 --- a/src/checkedframe/_checks.py +++ b/src/checkedframe/_checks.py @@ -85,7 +85,7 @@ def _is_expr(x: Any) -> bool: def _is_dataframe(x: Any) -> bool: return ( - isinstance(x, nw.DataFrame) + issubclass(x, nw.DataFrame) or _is_polars_dataframe(x) or _is_pandas_dataframe(x) or _is_modin_dataframe(x) @@ -117,25 +117,13 @@ def _infer_input_type( return "auto" -def _infer_return_type( - type_hints: dict[str, Any], input_type: CheckInputType -) -> CheckReturnType: - try: - # Try to get it from the type hints first - type_hint = type_hints["return"] - - if issubclass(type_hint, bool): - return "bool" - elif _is_expr(type_hint): - return "Expr" - elif _is_series(type_hint): - return "Series" - except KeyError: - # If type hints don't exist, we try to infer from the input_type - pass - - if input_type == "str" or input_type is None: +def _infer_return_type(typ) -> CheckReturnType: + if issubclass(typ, bool): + return "bool" + elif _is_expr(typ): return "Expr" + elif _is_series(typ): + return "Series" return "auto" @@ -501,10 +489,12 @@ def _set_params(self) -> None: self.input_type = _infer_input_type(type_hints, signature) if auto_return_type: - self.return_type = _infer_return_type( - type_hints, - self.input_type, - ) + try: + self.return_type = _infer_return_type( + type_hints["return"], + ) + except KeyError: + self.return_type = "auto" if self.native == "auto": raise ValueError( @@ -516,11 +506,6 @@ def _set_params(self) -> None: f"Input type of `{self.name}` could not be automatically determined from context" ) - if self.return_type == "auto": - raise ValueError( - f"Return type of `{self.name}` could not be automatically determined from context" - ) - if self.name is None: self.name = None if self.func.__name__ == "" else self.func.__name__ diff --git a/src/checkedframe/_core.py b/src/checkedframe/_core.py index 77ca967..b147612 100644 --- a/src/checkedframe/_core.py +++ b/src/checkedframe/_core.py @@ -9,7 +9,7 @@ import narwhals.stable.v1 as nw import narwhals.stable.v1.typing as nwt -from ._checks import Check +from ._checks import Check, _infer_return_type from ._config import ConfigList from ._dtypes import CastError, CfUnion, TypedColumn, _nw_type_to_cf_type from ._utils import get_class_members @@ -75,67 +75,78 @@ def _run_check( {"check_name": check_name, "check_description": check.description} ) - if check.return_type == "Expr": - if check.input_type == "str": - expr = check.func(series_name) - else: - expr = check.func() + if check.input_type == "str": + res = check.func(series_name) + elif check.input_type is None: + res = check.func() + elif check.input_type == "Series": + if series_name is None: + raise ValueError( + "Series cannot be automatically determined in this context" + ) + + input_ = nw_df[series_name] + + if check.native: + input_ = input_.to_native() + + res = check.func(input_) + elif check.input_type == "Frame": + # mypy complains here that the input type is Series, not DataFrame, but it + # is only a Series if the above branch is hit, which means this branch is + # not + input_ = nw_df # type: ignore[assignment] + + if check.native: + input_ = input_.to_native() - assert isinstance(check.native, bool) + res = check.func(input_) + else: + # We should never hit this branch since the input type always needs to be + # specified statically + raise ValueError("Invalid input type") + + check_return_type = check.return_type + if check_return_type == "auto": + # We need to get the type (the class) instead of the instance + check_return_type = _infer_return_type(type(res)) + + if check_return_type == "Expr": return _ResultWrapper( - expr, + res, msg=err_msg, identifier=new_check_name, column=column_name, operation=check_name, - native=check.native, + # We have already transformed this to a boolean + native=check.native, # type: ignore[arg-type] + is_expr=True, + ) + elif check_return_type == "Series": + res = nw.from_native(res, series_only=True) + return _ResultWrapper( + res, + msg=err_msg, + identifier=new_check_name, + column=column_name, + operation=check_name, + native=False, + is_expr=False, + ) + elif check_return_type == "bool": + res = nw.lit(res) + return _ResultWrapper( + res, + msg=err_msg, + identifier=new_check_name, + column=column_name, + operation=check_name, + native=False, is_expr=True, ) else: - if check.input_type == "Series": - if series_name is None: - raise ValueError( - "Series cannot be automatically determined in this context" - ) - - input_ = nw_df[series_name] - elif check.input_type == "Frame": - # mypy complains here that the input type is Series, not DataFrame, but it - # is only a Series if the above branch is hit, which means this branch is - # not - - input_ = nw_df # type: ignore[assignment] - else: - raise ValueError("Invalid input type") - - if check.native: - input_ = input_.to_native() - - res = check.func(input_) - - if check.return_type == "Series": - res = nw.from_native(res, series_only=True) - return _ResultWrapper( - res, - msg=err_msg, - identifier=new_check_name, - column=column_name, - operation=check_name, - native=False, - is_expr=False, - ) - elif check.return_type == "bool": - res = nw.lit(res) - return _ResultWrapper( - res, - msg=err_msg, - identifier=new_check_name, - column=column_name, - operation=check_name, - native=False, - is_expr=True, - ) + raise ValueError(f"Invalid return_type {check_return_type}") @dataclasses.dataclass diff --git a/tests/test_checks.py b/tests/test_checks.py index 96fcc70..1b69f75 100644 --- a/tests/test_checks.py +++ b/tests/test_checks.py @@ -91,6 +91,30 @@ def a_check3(s: pd.Series) -> bool: assert col_checks["a_check3"].return_type == "bool" +def test_raises_if_type_inference_fails(): + with pytest.raises(ValueError): + + class S(cf.Schema): + a = cf.String() + + @cf.Check(columns="a") + def invalid_input(s) -> pl.Series: + return s.str.len_chars() < 3 + + +def test_type_inference_from_object(): + class S(cf.Schema): + a = cf.String() + + @cf.Check(columns="a") + def a_check(s: pl.Series): + return s.str.len_chars().lt(3) + + df = pl.DataFrame({"a": ["a", "b", "c"]}) + + S.validate(df) + + @pytest.mark.parametrize("engine", ENGINES) def test_is_between(engine): df = engine({"a": [1, 2, 3]})