Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ dependencies = ["narwhals>=1.0.0"]
[project.optional-dependencies]
pyperclip = ["pyperclip>=1.0.0"]

[dependency-groups]
typing = ["typing_extensions"]

[build-system]
requires = ["setuptools>=42", "wheel"]
Expand All @@ -27,3 +29,4 @@ profile = "black"

[tool.mypy]
plugins = ["checkedframe.mypy"]
allow_redefinition = true
1 change: 1 addition & 0 deletions src/checkedframe/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from checkedframe import exceptions, selectors

from ._checks import Check, col, lit
from ._config import Config, apply_configs
from ._core import Schema
from ._dtypes import (
Array,
Expand Down
49 changes: 49 additions & 0 deletions src/checkedframe/_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from __future__ import annotations

from collections.abc import Iterable
from typing import TYPE_CHECKING, TypedDict, Union

from .selectors import Selector, by_name

if TYPE_CHECKING:
from typing_extensions import Unpack

SelectorLike = Union[str, Iterable[str], Selector]


class _PossibleConfigs(TypedDict):
nullable: bool
required: bool
cast: bool
allow_nan: bool
allow_inf: bool


class Config:
def __init__(
self,
selector: SelectorLike,
**kwargs: Unpack[_PossibleConfigs],
):
if not isinstance(selector, Selector):
actual_selector = by_name(selector)
else:
actual_selector = selector

self.selector = actual_selector
self.dct = kwargs


# This just makes it easier to do isinstance checks
class ConfigList:
def __init__(self, *args: Config):
self.args = args


def apply_configs(*args: Config):
def decorator(cls):
cls.__private_checkedframe_config = ConfigList(*args)

return cls

return decorator
60 changes: 29 additions & 31 deletions src/checkedframe/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
import narwhals.stable.v1 as nw
import narwhals.stable.v1.typing as nwt

from ._checks import Check, CheckInputType
from ._checks import Check
from ._config import ConfigList
from ._dtypes import CastError, _Column, _nw_type_to_cf_type, _TypedColumn
from ._utils import get_class_members
from .exceptions import SchemaError
Expand Down Expand Up @@ -153,7 +154,7 @@ class InterrogationResult:


def _private_interrogate(
schema: Schema, df: nwt.IntoDataFrameT, cast: bool
schema: Schema, df: nwt.IntoDataFrameT
) -> _PrivateInterrogationResult:
nw_df = nw.from_native(df, eager_only=True)
df_schema = nw_df.collect_schema() # type: ignore[attribute]
Expand Down Expand Up @@ -201,7 +202,7 @@ def _private_interrogate(
if actual_dtype == expected_col.to_narwhals():
pass
else:
if expected_col.cast or cast:
if expected_col.cast:
try:
nw_df = nw_df.with_columns(
_nw_type_to_cf_type(actual_dtype)._safe_cast(
Expand Down Expand Up @@ -351,9 +352,10 @@ def _private_interrogate(


def _interrogate(
schema: Schema, df: nwt.IntoDataFrameT, cast: bool
schema: Schema,
df: nwt.IntoDataFrameT,
) -> InterrogationResult:
res = _private_interrogate(schema=schema, df=df, cast=cast)
res = _private_interrogate(schema=schema, df=df)

return InterrogationResult(
df=res.df.to_native(),
Expand All @@ -365,8 +367,8 @@ def _interrogate(
)


def _filter(schema: Schema, df: nwt.IntoDataFrameT, cast: bool) -> nwt.IntoDataFrameT:
res = _private_interrogate(schema=schema, df=df, cast=cast)
def _filter(schema: Schema, df: nwt.IntoDataFrameT) -> nwt.IntoDataFrameT:
res = _private_interrogate(schema=schema, df=df)

return res.df.filter(res.is_good).to_native()

Expand Down Expand Up @@ -430,8 +432,8 @@ def _wrap_err(e: str) -> str:
return "\n".join(error_summary + output)


def _validate(schema: Schema, df: nwt.IntoDataFrameT, cast: bool) -> nwt.IntoDataFrameT:
res = _private_interrogate(schema, df, cast)
def _validate(schema: Schema, df: nwt.IntoDataFrameT) -> nwt.IntoDataFrameT:
res = _private_interrogate(schema, df)

if not res.is_good.all():
raise SchemaError(
Expand Down Expand Up @@ -569,33 +571,33 @@ def _parse_into_schema(cls) -> Schema:
else:
checks.append(val)

if isinstance(val, ConfigList):
for config in val.args:
for c in config.selector(schema_dict):
for k, v in config.dct.items():
setattr(schema_dict[c], k, v)

res = Schema(expected_schema=schema_dict, checks=checks)
cls._schema = res

return res

@classmethod
def interrogate(
cls, df: nwt.IntoDataFrameT, cast: bool = False
) -> InterrogationResult:
return _interrogate(cls._parse_into_schema(), df, cast)
def interrogate(cls, df: nwt.IntoDataFrameT) -> InterrogationResult:
return _interrogate(cls._parse_into_schema(), df)

def __interrogate(
self, df: nwt.IntoDataFrameT, cast: bool = False
) -> InterrogationResult:
return _interrogate(self, df, cast)
def __interrogate(self, df: nwt.IntoDataFrameT) -> InterrogationResult:
return _interrogate(self, df)

@classmethod
def validate(cls, df: nwt.IntoDataFrameT, cast: bool = False) -> nwt.IntoDataFrameT:
def validate(cls, df: nwt.IntoDataFrameT) -> nwt.IntoDataFrameT:
"""Validate the given DataFrame

Parameters
----------
df : nwt.IntoDataFrameT
Any Narwhals-compatible DataFrame, see https://narwhals-dev.github.io/narwhals/
for more information
cast : bool, optional
Whether to cast columns, by default False

Returns
-------
Expand All @@ -607,18 +609,14 @@ def validate(cls, df: nwt.IntoDataFrameT, cast: bool = False) -> nwt.IntoDataFra
SchemaError
If validation fails
"""
return _validate(cls._parse_into_schema(), df, cast)
return _validate(cls._parse_into_schema(), df)

def __validate(
self, df: nwt.IntoDataFrameT, cast: bool = False
) -> nwt.IntoDataFrameT:
return _validate(self, df, cast)
def __validate(self, df: nwt.IntoDataFrameT) -> nwt.IntoDataFrameT:
return _validate(self, df)

@classmethod
def filter(cls, df: nwt.IntoDataFrameT, cast: bool = False) -> nwt.IntoDataFrameT:
return _filter(cls._parse_into_schema(), df, cast)
def filter(cls, df: nwt.IntoDataFrameT) -> nwt.IntoDataFrameT:
return _filter(cls._parse_into_schema(), df)

def __filter(
self, df: nwt.IntoDataFrameT, cast: bool = False
) -> nwt.IntoDataFrameT:
return _filter(self, df, cast)
def __filter(self, df: nwt.IntoDataFrameT) -> nwt.IntoDataFrameT:
return _filter(self, df)
8 changes: 0 additions & 8 deletions src/checkedframe/_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,6 @@ class _BoundedDType(_DType):
_max: int | float


class _ColumnKwargs(TypedDict):
name: Optional[str]
nullable: bool
required: bool
cast: bool
checks: Optional[list[Check]]


class _Column:
"""Represents a column in a DataFrame.

Expand Down
8 changes: 8 additions & 0 deletions src/checkedframe/selectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,14 @@ def __xor__(self, other: Selector) -> Selector:
lambda col, dtype: self.condition(col, dtype) != other.condition(col, dtype)
)

def exclude(self, other: str | Iterable[str] | Selector) -> Selector:
if not isinstance(other, Selector):
selector_other = by_name(other)
else:
selector_other = other

return self.__sub__(selector_other)


def _flatten_str_iterable(lst: Iterable[str | Iterable[str]]) -> list[str]:
res = []
Expand Down
24 changes: 24 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import checkedframe as cf


def test_config():
@cf.apply_configs(cf.Config(cf.selectors.float(), nullable=True))
class S(cf.Schema):
float1 = cf.Float64()
float2 = cf.Float64()
int1 = cf.Int64()

s = S._parse_into_schema().expected_schema

assert s["float1"].nullable
assert s["float2"].nullable
assert not s["int1"].nullable

class S2(S):
float3 = cf.Float64()

s2 = S2._parse_into_schema().expected_schema

assert s["float1"].nullable
assert s["float2"].nullable
assert s2["float3"].nullable
6 changes: 6 additions & 0 deletions tests/test_selectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,9 @@ def test_boolean():

def test_numeric():
assert set(cfs.numeric()(SCHEMA)) == set(["int8", "float32"])


def test_exclude():
assert cfs.numeric().exclude("int8")(SCHEMA) == ["float32"]
assert cfs.numeric().exclude(cfs.by_name("int8"))(SCHEMA) == ["float32"]
assert cfs.numeric().exclude(["int8"])(SCHEMA) == ["float32"]
58 changes: 36 additions & 22 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.