From 75da3d0777c866cf71a73ac8ed0f784b9a6d6af8 Mon Sep 17 00:00:00 2001 From: Cangyuan Li Date: Thu, 10 Jul 2025 21:39:39 -0700 Subject: [PATCH 1/2] use config instead of passing options to validate --- pyproject.toml | 3 ++ src/checkedframe/__init__.py | 1 + src/checkedframe/_config.py | 49 +++++++++++++++++++++++++++++ src/checkedframe/_core.py | 60 +++++++++++++++++------------------- src/checkedframe/_dtypes.py | 8 ----- tests/test_config.py | 24 +++++++++++++++ uv.lock | 58 +++++++++++++++++++++------------- 7 files changed, 142 insertions(+), 61 deletions(-) create mode 100644 src/checkedframe/_config.py create mode 100644 tests/test_config.py diff --git a/pyproject.toml b/pyproject.toml index 62d8223..5cd3c82 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,6 +14,8 @@ dependencies = ["narwhals>=1.0.0"] [project.optional-dependencies] pyperclip = ["pyperclip>=1.0.0"] +[dependency-groups] +typing = ["typing_extensions"] [build-system] requires = ["setuptools>=42", "wheel"] @@ -27,3 +29,4 @@ profile = "black" [tool.mypy] plugins = ["checkedframe.mypy"] +allow_redefinition = true diff --git a/src/checkedframe/__init__.py b/src/checkedframe/__init__.py index 9923e7e..142074a 100644 --- a/src/checkedframe/__init__.py +++ b/src/checkedframe/__init__.py @@ -2,6 +2,7 @@ from checkedframe import exceptions, selectors from ._checks import Check, col, lit +from ._config import Config, apply_configs from ._core import Schema from ._dtypes import ( Array, diff --git a/src/checkedframe/_config.py b/src/checkedframe/_config.py new file mode 100644 index 0000000..70ef4a3 --- /dev/null +++ b/src/checkedframe/_config.py @@ -0,0 +1,49 @@ +from __future__ import annotations + +from collections.abc import Iterable +from typing import TYPE_CHECKING, TypedDict, Union + +from .selectors import Selector, by_name + +if TYPE_CHECKING: + from typing_extensions import Unpack + +SelectorLike = Union[str, Iterable[str], Selector] + + +class _PossibleConfigs(TypedDict): + nullable: bool + required: bool + cast: bool + allow_nan: bool + allow_inf: bool + + +class Config: + def __init__( + self, + selector: SelectorLike, + **kwargs: Unpack[_PossibleConfigs], + ): + if not isinstance(selector, Selector): + actual_selector = by_name(selector) + else: + actual_selector = selector + + self.selector = actual_selector + self.dct = kwargs + + +# This just makes it easier to do isinstance checks +class ConfigList: + def __init__(self, *args: Config): + self.args = args + + +def apply_configs(*args: Config): + def decorator(cls): + cls.__private_checkedframe_config = ConfigList(*args) + + return cls + + return decorator diff --git a/src/checkedframe/_core.py b/src/checkedframe/_core.py index f31ee66..9d337cb 100644 --- a/src/checkedframe/_core.py +++ b/src/checkedframe/_core.py @@ -8,7 +8,8 @@ import narwhals.stable.v1 as nw import narwhals.stable.v1.typing as nwt -from ._checks import Check, CheckInputType +from ._checks import Check +from ._config import ConfigList from ._dtypes import CastError, _Column, _nw_type_to_cf_type, _TypedColumn from ._utils import get_class_members from .exceptions import SchemaError @@ -153,7 +154,7 @@ class InterrogationResult: def _private_interrogate( - schema: Schema, df: nwt.IntoDataFrameT, cast: bool + schema: Schema, df: nwt.IntoDataFrameT ) -> _PrivateInterrogationResult: nw_df = nw.from_native(df, eager_only=True) df_schema = nw_df.collect_schema() # type: ignore[attribute] @@ -201,7 +202,7 @@ def _private_interrogate( if actual_dtype == expected_col.to_narwhals(): pass else: - if expected_col.cast or cast: + if expected_col.cast: try: nw_df = nw_df.with_columns( _nw_type_to_cf_type(actual_dtype)._safe_cast( @@ -351,9 +352,10 @@ def _private_interrogate( def _interrogate( - schema: Schema, df: nwt.IntoDataFrameT, cast: bool + schema: Schema, + df: nwt.IntoDataFrameT, ) -> InterrogationResult: - res = _private_interrogate(schema=schema, df=df, cast=cast) + res = _private_interrogate(schema=schema, df=df) return InterrogationResult( df=res.df.to_native(), @@ -365,8 +367,8 @@ def _interrogate( ) -def _filter(schema: Schema, df: nwt.IntoDataFrameT, cast: bool) -> nwt.IntoDataFrameT: - res = _private_interrogate(schema=schema, df=df, cast=cast) +def _filter(schema: Schema, df: nwt.IntoDataFrameT) -> nwt.IntoDataFrameT: + res = _private_interrogate(schema=schema, df=df) return res.df.filter(res.is_good).to_native() @@ -430,8 +432,8 @@ def _wrap_err(e: str) -> str: return "\n".join(error_summary + output) -def _validate(schema: Schema, df: nwt.IntoDataFrameT, cast: bool) -> nwt.IntoDataFrameT: - res = _private_interrogate(schema, df, cast) +def _validate(schema: Schema, df: nwt.IntoDataFrameT) -> nwt.IntoDataFrameT: + res = _private_interrogate(schema, df) if not res.is_good.all(): raise SchemaError( @@ -569,24 +571,26 @@ def _parse_into_schema(cls) -> Schema: else: checks.append(val) + if isinstance(val, ConfigList): + for config in val.args: + for c in config.selector(schema_dict): + for k, v in config.dct.items(): + setattr(schema_dict[c], k, v) + res = Schema(expected_schema=schema_dict, checks=checks) cls._schema = res return res @classmethod - def interrogate( - cls, df: nwt.IntoDataFrameT, cast: bool = False - ) -> InterrogationResult: - return _interrogate(cls._parse_into_schema(), df, cast) + def interrogate(cls, df: nwt.IntoDataFrameT) -> InterrogationResult: + return _interrogate(cls._parse_into_schema(), df) - def __interrogate( - self, df: nwt.IntoDataFrameT, cast: bool = False - ) -> InterrogationResult: - return _interrogate(self, df, cast) + def __interrogate(self, df: nwt.IntoDataFrameT) -> InterrogationResult: + return _interrogate(self, df) @classmethod - def validate(cls, df: nwt.IntoDataFrameT, cast: bool = False) -> nwt.IntoDataFrameT: + def validate(cls, df: nwt.IntoDataFrameT) -> nwt.IntoDataFrameT: """Validate the given DataFrame Parameters @@ -594,8 +598,6 @@ def validate(cls, df: nwt.IntoDataFrameT, cast: bool = False) -> nwt.IntoDataFra df : nwt.IntoDataFrameT Any Narwhals-compatible DataFrame, see https://narwhals-dev.github.io/narwhals/ for more information - cast : bool, optional - Whether to cast columns, by default False Returns ------- @@ -607,18 +609,14 @@ def validate(cls, df: nwt.IntoDataFrameT, cast: bool = False) -> nwt.IntoDataFra SchemaError If validation fails """ - return _validate(cls._parse_into_schema(), df, cast) + return _validate(cls._parse_into_schema(), df) - def __validate( - self, df: nwt.IntoDataFrameT, cast: bool = False - ) -> nwt.IntoDataFrameT: - return _validate(self, df, cast) + def __validate(self, df: nwt.IntoDataFrameT) -> nwt.IntoDataFrameT: + return _validate(self, df) @classmethod - def filter(cls, df: nwt.IntoDataFrameT, cast: bool = False) -> nwt.IntoDataFrameT: - return _filter(cls._parse_into_schema(), df, cast) + def filter(cls, df: nwt.IntoDataFrameT) -> nwt.IntoDataFrameT: + return _filter(cls._parse_into_schema(), df) - def __filter( - self, df: nwt.IntoDataFrameT, cast: bool = False - ) -> nwt.IntoDataFrameT: - return _filter(self, df, cast) + def __filter(self, df: nwt.IntoDataFrameT) -> nwt.IntoDataFrameT: + return _filter(self, df) diff --git a/src/checkedframe/_dtypes.py b/src/checkedframe/_dtypes.py index c6f7840..afd361e 100644 --- a/src/checkedframe/_dtypes.py +++ b/src/checkedframe/_dtypes.py @@ -34,14 +34,6 @@ class _BoundedDType(_DType): _max: int | float -class _ColumnKwargs(TypedDict): - name: Optional[str] - nullable: bool - required: bool - cast: bool - checks: Optional[list[Check]] - - class _Column: """Represents a column in a DataFrame. diff --git a/tests/test_config.py b/tests/test_config.py new file mode 100644 index 0000000..3b3e49c --- /dev/null +++ b/tests/test_config.py @@ -0,0 +1,24 @@ +import checkedframe as cf + + +def test_config(): + @cf.apply_configs(cf.Config(cf.selectors.float(), nullable=True)) + class S(cf.Schema): + float1 = cf.Float64() + float2 = cf.Float64() + int1 = cf.Int64() + + s = S._parse_into_schema().expected_schema + + assert s["float1"].nullable + assert s["float2"].nullable + assert not s["int1"].nullable + + class S2(S): + float3 = cf.Float64() + + s2 = S2._parse_into_schema().expected_schema + + assert s["float1"].nullable + assert s["float2"].nullable + assert s2["float3"].nullable diff --git a/uv.lock b/uv.lock index 4c54ed5..d007d92 100644 --- a/uv.lock +++ b/uv.lock @@ -1,43 +1,57 @@ version = 1 revision = 2 -requires-python = ">=3.8" -resolution-markers = [ - "python_full_version >= '3.9'", - "python_full_version < '3.9'", -] +requires-python = ">=3.9" [[package]] name = "checkedframe" -version = "0.0.5" +version = "0.0.9" source = { editable = "." } dependencies = [ - { name = "narwhals", version = "1.42.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, - { name = "narwhals", version = "1.43.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" }, + { name = "narwhals" }, + { name = "typing-extensions" }, ] -[package.metadata] -requires-dist = [{ name = "narwhals", specifier = ">=1.0.0" }] +[package.optional-dependencies] +pyperclip = [ + { name = "pyperclip" }, +] -[[package]] -name = "narwhals" -version = "1.42.1" -source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "python_full_version < '3.9'", +[package.dev-dependencies] +typing = [ + { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/df/d6/168a787b7800d6c89846b791e4f5ee6b94998a80c8c2838a019d3d71984d/narwhals-1.42.1.tar.gz", hash = "sha256:50a5635b11aeda98cf9c37e839fd34b0a24159f59a4dfae930290ad698320494", size = 492865, upload-time = "2025-06-12T15:15:13.222Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/79/3f/8d450588206b437dd239a6d44230c63095e71135bd95d5a74347d07adbd5/narwhals-1.42.1-py3-none-any.whl", hash = "sha256:7a270d44b94ccdb277a799ae890c42e8504c537c1849f195eb14717c6184977a", size = 359888, upload-time = "2025-06-12T15:15:11.643Z" }, + +[package.metadata] +requires-dist = [ + { name = "narwhals", specifier = ">=1.0.0" }, + { name = "pyperclip", marker = "extra == 'pyperclip'", specifier = ">=1.0.0" }, + { name = "typing-extensions", specifier = ">=4.14.1" }, ] +provides-extras = ["pyperclip"] + +[package.metadata.requires-dev] +typing = [{ name = "typing-extensions" }] [[package]] name = "narwhals" version = "1.43.1" source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "python_full_version >= '3.9'", -] sdist = { url = "https://files.pythonhosted.org/packages/61/82/9f351a79260a6456db3f53d248268b4c3791f1e3228eec3c745e8816afd6/narwhals-1.43.1.tar.gz", hash = "sha256:6ff56d600da67a0a0980b83bd5577d076772fdba96474076ba4e76c920dbc1e5", size = 496655, upload-time = "2025-06-19T09:37:56.398Z" } wheels = [ { url = "https://files.pythonhosted.org/packages/8f/1e/b741d4eabbde95b1790e7df3c33c6b19f9b48db98a1416c6a6f06572bc66/narwhals-1.43.1-py3-none-any.whl", hash = "sha256:1ee508fa4dc0e05aa5b88717ba11613d8d9ccf0dd1e48513d4a3afb237dba9f2", size = 362737, upload-time = "2025-06-19T09:37:54.415Z" }, ] + +[[package]] +name = "pyperclip" +version = "1.9.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/30/23/2f0a3efc4d6a32f3b63cdff36cd398d9701d26cda58e3ab97ac79fb5e60d/pyperclip-1.9.0.tar.gz", hash = "sha256:b7de0142ddc81bfc5c7507eea19da920b92252b548b96186caf94a5e2527d310", size = 20961, upload-time = "2024-06-18T20:38:48.401Z" } + +[[package]] +name = "typing-extensions" +version = "4.14.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/98/5a/da40306b885cc8c09109dc2e1abd358d5684b1425678151cdaed4731c822/typing_extensions-4.14.1.tar.gz", hash = "sha256:38b39f4aeeab64884ce9f74c94263ef78f3c22467c8724005483154c26648d36", size = 107673, upload-time = "2025-07-04T13:28:34.16Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b5/00/d631e67a838026495268c2f6884f3711a15a9a2a96cd244fdaea53b823fb/typing_extensions-4.14.1-py3-none-any.whl", hash = "sha256:d1e1e3b58374dc93031d6eda2420a48ea44a36c2b4766a4fdeb3710755731d76", size = 43906, upload-time = "2025-07-04T13:28:32.743Z" }, +] From fb03ffdf572ccd3354cd654c2404565bdef26f6d Mon Sep 17 00:00:00 2001 From: Cangyuan Li Date: Thu, 10 Jul 2025 21:39:48 -0700 Subject: [PATCH 2/2] add exclude to selectors --- src/checkedframe/selectors.py | 8 ++++++++ tests/test_selectors.py | 6 ++++++ 2 files changed, 14 insertions(+) diff --git a/src/checkedframe/selectors.py b/src/checkedframe/selectors.py index 3e64435..9a3fa48 100644 --- a/src/checkedframe/selectors.py +++ b/src/checkedframe/selectors.py @@ -41,6 +41,14 @@ def __xor__(self, other: Selector) -> Selector: lambda col, dtype: self.condition(col, dtype) != other.condition(col, dtype) ) + def exclude(self, other: str | Iterable[str] | Selector) -> Selector: + if not isinstance(other, Selector): + selector_other = by_name(other) + else: + selector_other = other + + return self.__sub__(selector_other) + def _flatten_str_iterable(lst: Iterable[str | Iterable[str]]) -> list[str]: res = [] diff --git a/tests/test_selectors.py b/tests/test_selectors.py index 32c9b57..d1d60ec 100644 --- a/tests/test_selectors.py +++ b/tests/test_selectors.py @@ -72,3 +72,9 @@ def test_boolean(): def test_numeric(): assert set(cfs.numeric()(SCHEMA)) == set(["int8", "float32"]) + + +def test_exclude(): + assert cfs.numeric().exclude("int8")(SCHEMA) == ["float32"] + assert cfs.numeric().exclude(cfs.by_name("int8"))(SCHEMA) == ["float32"] + assert cfs.numeric().exclude(["int8"])(SCHEMA) == ["float32"]