diff --git a/pyproject.toml b/pyproject.toml index 7b736cc..077fd46 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,7 +17,6 @@ pyperclip = ["pyperclip>=1.0.0"] [dependency-groups] typing = ["typing_extensions"] docs = [ - "checkedframe", "ghp-import==2.1.0", "pydata-sphinx-theme==0.16.1", "sphinx==8.2.3", @@ -53,6 +52,8 @@ profile = "black" plugins = ["checkedframe.mypy"] allow_redefinition = true +[tool.ruff] +ignore = ["E741"] [tool.ruff.format] docstring-code-format = true diff --git a/src/checkedframe/_checks.py b/src/checkedframe/_checks.py index e7b9bdc..11bab51 100644 --- a/src/checkedframe/_checks.py +++ b/src/checkedframe/_checks.py @@ -2,10 +2,11 @@ import functools import inspect -from collections.abc import Collection, Sequence -from typing import Any, Callable, Literal, Optional, get_type_hints +from collections.abc import Collection, Iterable, Sequence +from typing import Any, Callable, Literal, Optional, get_args, get_type_hints import narwhals.stable.v1 as nw +import narwhals.stable.v1.typing as nwt from narwhals.stable.v1.dependencies import ( get_cudf, get_modin, @@ -261,7 +262,14 @@ def _is_sorted(s: nw.Series, descending: bool) -> bool: def _is_id(df: nw.DataFrame, subset: str | list[str]) -> bool: n_rows = df.shape[0] - n_unique_rows = df.select(subset).unique().shape[0] + + # n_unique on dataframes is not available on narhwals, so if we have only one + # column specified as the subset, take a potential fast path, otherwise fallback to + # a generic version + if isinstance(subset, str): + n_unique_rows = df[subset].n_unique() + else: + n_unique_rows = df.select(subset).unique().shape[0] return n_rows == n_unique_rows @@ -364,6 +372,69 @@ def _str_contains(name: str, pattern: str, literal: bool = False) -> nw.Expr: return nw.col(name).str.contains(pattern, literal=literal) +CardinalityRatio = Literal["1:1", "1:m", "m:1"] + + +def _cardinality_ratio( + df: nw.DataFrame, + left: str, + right: str, + cardinality: CardinalityRatio, + by: str | list[str] | None = None, + allow_duplicates: bool = False, +): + index_col = "__checkedframe_temp_cardinality_ratio_private_index__" + result_col = left + + original_lf = df.with_row_index(index_col).lazy() + + if by is None: + by = "__checkedframe_temp_cardinality_ratio_private_by__" + original_lf = original_lf.with_columns(nw.lit(1).alias(by)) + + if isinstance(by, str): + by = [by] + + lf = original_lf.select(left, right, *by) + + if allow_duplicates: + lf = lf.unique() + + if cardinality == "1:1": + result_lf = ( + lf.group_by(by) + .agg( + nw.col(left).n_unique().__eq__(nw.len()), + nw.col(right).n_unique().__eq__(nw.len()), + ) + .select(*by, nw.col(left).__and__(nw.col(right)).alias(result_col)) + ) + elif cardinality == "1:m": + result_lf = ( + lf.group_by(by) + .agg(nw.col(left).n_unique().__eq__(nw.len()).alias(result_col)) + .select(*by, result_col) + ) + elif cardinality == "m:1": + result_lf = ( + lf.group_by(by) + .agg(nw.col(right).n_unique().__eq__(nw.len()).alias(result_col)) + .select(*by, result_col) + ) + else: + raise ValueError( + f"Invalid cardinality `{cardinality}`, must be one of `{get_args(CardinalityRatio)}`" + ) + + return ( + original_lf.select(index_col, *by) + .join(result_lf, on=by, how="left") + .sort(index_col) # joins are not guaranteed to preserve order + .select(result_col) + .collect()[result_col] + ) + + CheckInputType = Optional[Literal["auto", "Frame", "str", "Series"]] CheckReturnType = Literal["auto", "bool", "Expr", "Series"] @@ -1370,9 +1441,10 @@ def is_id(subset: str | list[str]) -> Check: class MySchema(cf.Schema): - __dataframe_checks__ = [cf.Check.is_id("group")] group = cf.String() + _id_check = cf.Check.is_id("group") + df = pl.DataFrame({"group": ["A", "B", "A"]}) MySchema.validate(df) @@ -1382,8 +1454,7 @@ class MySchema(cf.Schema): .. code-block:: text SchemaError: Found 1 error(s) - __dataframe__: 1 error(s) - - is_id failed: 'group' must uniquely identify the DataFrame + * is_id failed for 3 / 3 (100.00%) rows: group must uniquely identify the DataFrame """ return Check( @@ -1394,3 +1465,27 @@ class MySchema(cf.Schema): name="is_id", description=f"{subset} must uniquely identify the DataFrame", ) + + @staticmethod + def cardinality_ratio( + left: str, + right: str, + cardinality: CardinalityRatio, + by: str | list[str] | None = None, + allow_duplicates: bool = False, + ) -> Check: + return Check( + func=functools.partial( + _cardinality_ratio, + left=left, + right=right, + cardinality=cardinality, + by=by, + allow_duplicates=allow_duplicates, + ), + input_type="Frame", + return_type="Series", + native=False, + name="cardinality_ratio", + description=f"The relationship between {left} and {right} must be {cardinality} (by={by}, allow_duplicates={allow_duplicates})", + ) diff --git a/tests/test_checks.py b/tests/test_checks.py index 535c4d5..96fcc70 100644 --- a/tests/test_checks.py +++ b/tests/test_checks.py @@ -4,6 +4,7 @@ import pytest import checkedframe as cf +from checkedframe.exceptions import SchemaError ENGINES = [pd.DataFrame, pl.DataFrame] @@ -338,3 +339,40 @@ class S(cf.Schema): _c = cf.Check.is_id(["a", "b"]) S.validate(df) + + +@pytest.mark.parametrize("engine", ENGINES) +def test_cardinality_ratio(engine): + df = engine( + { + "feature": ["f1", "f1", "f1", "f2"], + "special_value": [-1, -6, -4, -7], + "imputed": [None, None, "MAX_WIN_P1", None], + "reason": ["o1", "o1", "o2", "o3"], + } + ) + + class S(cf.Schema): + _c = cf.Check.cardinality_ratio("imputed", "reason", "1:1", by="feature") + + with pytest.raises(SchemaError): + S.validate(df) + + class S(cf.Schema): + _c = cf.Check.cardinality_ratio("imputed", "reason", "1:m", by="feature") + + with pytest.raises(SchemaError): + S.validate(df) + + class S(cf.Schema): + _c = cf.Check.cardinality_ratio("imputed", "reason", "m:1", by="feature") + + with pytest.raises(SchemaError): + S.validate(df) + + class S(cf.Schema): + _c = cf.Check.cardinality_ratio( + "imputed", "reason", "m:1", by="feature", allow_duplicates=True + ) + + S.validate(df)