Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ pyperclip = ["pyperclip>=1.0.0"]
[dependency-groups]
typing = ["typing_extensions"]
docs = [
"checkedframe",
"ghp-import==2.1.0",
"pydata-sphinx-theme==0.16.1",
"sphinx==8.2.3",
Expand Down Expand Up @@ -53,6 +52,8 @@ profile = "black"
plugins = ["checkedframe.mypy"]
allow_redefinition = true

[tool.ruff]
ignore = ["E741"]

[tool.ruff.format]
docstring-code-format = true
107 changes: 101 additions & 6 deletions src/checkedframe/_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@

import functools
import inspect
from collections.abc import Collection, Sequence
from typing import Any, Callable, Literal, Optional, get_type_hints
from collections.abc import Collection, Iterable, Sequence
from typing import Any, Callable, Literal, Optional, get_args, get_type_hints

import narwhals.stable.v1 as nw
import narwhals.stable.v1.typing as nwt
from narwhals.stable.v1.dependencies import (
get_cudf,
get_modin,
Expand Down Expand Up @@ -261,7 +262,14 @@ def _is_sorted(s: nw.Series, descending: bool) -> bool:

def _is_id(df: nw.DataFrame, subset: str | list[str]) -> bool:
n_rows = df.shape[0]
n_unique_rows = df.select(subset).unique().shape[0]

# n_unique on dataframes is not available on narhwals, so if we have only one
# column specified as the subset, take a potential fast path, otherwise fallback to
# a generic version
if isinstance(subset, str):
n_unique_rows = df[subset].n_unique()
else:
n_unique_rows = df.select(subset).unique().shape[0]

return n_rows == n_unique_rows

Expand Down Expand Up @@ -364,6 +372,69 @@ def _str_contains(name: str, pattern: str, literal: bool = False) -> nw.Expr:
return nw.col(name).str.contains(pattern, literal=literal)


CardinalityRatio = Literal["1:1", "1:m", "m:1"]


def _cardinality_ratio(
df: nw.DataFrame,
left: str,
right: str,
cardinality: CardinalityRatio,
by: str | list[str] | None = None,
allow_duplicates: bool = False,
):
index_col = "__checkedframe_temp_cardinality_ratio_private_index__"
result_col = left

original_lf = df.with_row_index(index_col).lazy()

if by is None:
by = "__checkedframe_temp_cardinality_ratio_private_by__"
original_lf = original_lf.with_columns(nw.lit(1).alias(by))

if isinstance(by, str):
by = [by]

lf = original_lf.select(left, right, *by)

if allow_duplicates:
lf = lf.unique()

if cardinality == "1:1":
result_lf = (
lf.group_by(by)
.agg(
nw.col(left).n_unique().__eq__(nw.len()),
nw.col(right).n_unique().__eq__(nw.len()),
)
.select(*by, nw.col(left).__and__(nw.col(right)).alias(result_col))
)
elif cardinality == "1:m":
result_lf = (
lf.group_by(by)
.agg(nw.col(left).n_unique().__eq__(nw.len()).alias(result_col))
.select(*by, result_col)
)
elif cardinality == "m:1":
result_lf = (
lf.group_by(by)
.agg(nw.col(right).n_unique().__eq__(nw.len()).alias(result_col))
.select(*by, result_col)
)
else:
raise ValueError(
f"Invalid cardinality `{cardinality}`, must be one of `{get_args(CardinalityRatio)}`"
)

return (
original_lf.select(index_col, *by)
.join(result_lf, on=by, how="left")
.sort(index_col) # joins are not guaranteed to preserve order
.select(result_col)
.collect()[result_col]
)


CheckInputType = Optional[Literal["auto", "Frame", "str", "Series"]]
CheckReturnType = Literal["auto", "bool", "Expr", "Series"]

Expand Down Expand Up @@ -1370,9 +1441,10 @@ def is_id(subset: str | list[str]) -> Check:


class MySchema(cf.Schema):
__dataframe_checks__ = [cf.Check.is_id("group")]
group = cf.String()

_id_check = cf.Check.is_id("group")


df = pl.DataFrame({"group": ["A", "B", "A"]})
MySchema.validate(df)
Expand All @@ -1382,8 +1454,7 @@ class MySchema(cf.Schema):
.. code-block:: text

SchemaError: Found 1 error(s)
__dataframe__: 1 error(s)
- is_id failed: 'group' must uniquely identify the DataFrame
* is_id failed for 3 / 3 (100.00%) rows: group must uniquely identify the DataFrame

"""
return Check(
Expand All @@ -1394,3 +1465,27 @@ class MySchema(cf.Schema):
name="is_id",
description=f"{subset} must uniquely identify the DataFrame",
)

@staticmethod
def cardinality_ratio(
left: str,
right: str,
cardinality: CardinalityRatio,
by: str | list[str] | None = None,
allow_duplicates: bool = False,
) -> Check:
return Check(
func=functools.partial(
_cardinality_ratio,
left=left,
right=right,
cardinality=cardinality,
by=by,
allow_duplicates=allow_duplicates,
),
input_type="Frame",
return_type="Series",
native=False,
name="cardinality_ratio",
description=f"The relationship between {left} and {right} must be {cardinality} (by={by}, allow_duplicates={allow_duplicates})",
)
38 changes: 38 additions & 0 deletions tests/test_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pytest

import checkedframe as cf
from checkedframe.exceptions import SchemaError

ENGINES = [pd.DataFrame, pl.DataFrame]

Expand Down Expand Up @@ -338,3 +339,40 @@ class S(cf.Schema):
_c = cf.Check.is_id(["a", "b"])

S.validate(df)


@pytest.mark.parametrize("engine", ENGINES)
def test_cardinality_ratio(engine):
df = engine(
{
"feature": ["f1", "f1", "f1", "f2"],
"special_value": [-1, -6, -4, -7],
"imputed": [None, None, "MAX_WIN_P1", None],
"reason": ["o1", "o1", "o2", "o3"],
}
)

class S(cf.Schema):
_c = cf.Check.cardinality_ratio("imputed", "reason", "1:1", by="feature")

with pytest.raises(SchemaError):
S.validate(df)

class S(cf.Schema):
_c = cf.Check.cardinality_ratio("imputed", "reason", "1:m", by="feature")

with pytest.raises(SchemaError):
S.validate(df)

class S(cf.Schema):
_c = cf.Check.cardinality_ratio("imputed", "reason", "m:1", by="feature")

with pytest.raises(SchemaError):
S.validate(df)

class S(cf.Schema):
_c = cf.Check.cardinality_ratio(
"imputed", "reason", "m:1", by="feature", allow_duplicates=True
)

S.validate(df)
Loading