From 03788098a89cf8d047bc39f31761d97869a9a18b Mon Sep 17 00:00:00 2001 From: Cangyuan Li Date: Wed, 17 Sep 2025 18:25:27 -0700 Subject: [PATCH] fix: move away from post-hoc string processing that fixes a lot of recursive reprs --- src/checkedframe/_dtypes.py | 144 +++++++++++++++++++++++++ src/checkedframe/_schema_generation.py | 4 +- tests/test_schema_parsing.py | 38 +++++-- 3 files changed, 174 insertions(+), 12 deletions(-) diff --git a/src/checkedframe/_dtypes.py b/src/checkedframe/_dtypes.py index 8e7e432..8e50faa 100644 --- a/src/checkedframe/_dtypes.py +++ b/src/checkedframe/_dtypes.py @@ -29,6 +29,10 @@ def to_narwhals() -> NarwhalsDType | type[NarwhalsDType]: ... @abstractmethod def _safe_cast(s: nw.Series, to_dtype: _DType) -> nw.Series: ... + @staticmethod + @abstractmethod + def _to_repr(prefix: str = "") -> str: ... + class _BoundedDType(_DType): _min: int | float @@ -205,6 +209,33 @@ def _union_cast(cls: type[TypedColumn], s: nw.Series, union: CfUnion) -> nw.Seri raise error +def _fmt_optional_string(s: str | None) -> str | None: + return f'"{s}"' if s is not None else None + + +def _fmt_dict_string(d: dict) -> str: + items = [] + for k, v in d.items(): + # Format the key with double quotes + key = f'"{k}"' + + # Check if the value is a dictionary and recurse + if isinstance(v, dict): + value = _fmt_dict_string(v) + # Check if the value is a string and remove quotes + elif isinstance(v, str): + value = v + # Otherwise, use json.dumps for standard formatting (e.g., numbers, lists) + else: + raise ValueError( + "All keys must be strings and all values must either be strings or dicts" + ) + + items.append(f"{key}: {value}") + + return "{" + ", ".join(items) + "}" + + class Int8(nw.Int8, _BoundedDType, TypedColumn): _min = -128 _max = 127 @@ -245,6 +276,10 @@ def _safe_cast(s: nw.Series, to_dtype: _DType | CfUnion) -> nw.Series: return _checked_cast(s, to_dtype) + @staticmethod + def _to_repr(prefix: str = "") -> str: + return f"{prefix}Int8" + class Int16(nw.Int16, _BoundedDType, TypedColumn): _min = -32_768 @@ -292,6 +327,10 @@ def _safe_cast(s: nw.Series, to_dtype: _DType | CfUnion) -> nw.Series: return _checked_cast(s, to_dtype) + @staticmethod + def _to_repr(prefix: str = "") -> str: + return f"{prefix}Int16" + class Int32(nw.Int32, _BoundedDType, TypedColumn): _min = -2_147_483_648 @@ -339,6 +378,10 @@ def _safe_cast(s: nw.Series, to_dtype: _DType | CfUnion) -> nw.Series: return _checked_cast(s, to_dtype) + @staticmethod + def _to_repr(prefix: str = "") -> str: + return f"{prefix}Int32" + class Int64(nw.Int64, TypedColumn, _BoundedDType): _min = -9_223_372_036_854_775_808 @@ -386,6 +429,10 @@ def _safe_cast(s: nw.Series, to_dtype: _DType | CfUnion) -> nw.Series: return _checked_cast(s, to_dtype) + @staticmethod + def _to_repr(prefix: str = "") -> str: + return f"{prefix}Int64" + class Int128(nw.Int128, TypedColumn, _BoundedDType): _min = -170141183460469231731687303715884105728 @@ -442,6 +489,10 @@ def _safe_cast(s: nw.Series, to_dtype: _DType | CfUnion) -> nw.Series: return _checked_cast(s, to_dtype) + @staticmethod + def _to_repr(prefix: str = "") -> str: + return f"{prefix}Int128" + class UInt8(nw.UInt8, TypedColumn, _BoundedDType): _min = 0 @@ -498,6 +549,10 @@ def _safe_cast(s: nw.Series, to_dtype: _DType | CfUnion) -> nw.Series: return _checked_cast(s, to_dtype) + @staticmethod + def _to_repr(prefix: str = "") -> str: + return f"{prefix}UInt8" + class UInt16(nw.UInt16, TypedColumn, _BoundedDType): _min = 0 @@ -552,6 +607,10 @@ def _safe_cast(s: nw.Series, to_dtype: _DType | CfUnion) -> nw.Series: return _checked_cast(s, to_dtype) + @staticmethod + def _to_repr(prefix: str = "") -> str: + return f"{prefix}UInt16" + class UInt32(nw.UInt32, TypedColumn, _BoundedDType): _min = 0 @@ -603,6 +662,10 @@ def _safe_cast(s: nw.Series, to_dtype: _DType | CfUnion) -> nw.Series: return _checked_cast(s, to_dtype) + @staticmethod + def _to_repr(prefix: str = "") -> str: + return f"{prefix}UInt32" + class UInt64(nw.UInt64, TypedColumn, _BoundedDType): _min = 0 @@ -661,6 +724,10 @@ def _safe_cast(s: nw.Series, to_dtype: _DType | CfUnion) -> nw.Series: return _checked_cast(s, to_dtype) + @staticmethod + def _to_repr(prefix: str = "") -> str: + return f"{prefix}UInt64" + class UInt128(nw.UInt128, TypedColumn, _BoundedDType): _min = 0 @@ -716,6 +783,10 @@ def _safe_cast(s: nw.Series, to_dtype: _DType | CfUnion) -> nw.Series: return _checked_cast(s, to_dtype) + @staticmethod + def _to_repr(prefix: str = "") -> str: + return f"{prefix}UInt128" + class Float32(nw.Float32, TypedColumn, _BoundedDType): # min and max represent min/max representible int that can be converted without loss @@ -778,6 +849,10 @@ def _safe_cast(s: nw.Series, to_dtype: _DType | CfUnion) -> nw.Series: return _checked_cast(s, to_dtype) + @staticmethod + def _to_repr(prefix: str = "") -> str: + return f"{prefix}Float32" + class Float64(nw.Float64, TypedColumn, _BoundedDType): _min = -9_007_199_254_740_991 @@ -837,6 +912,10 @@ def _safe_cast(s: nw.Series, to_dtype: _DType | CfUnion) -> nw.Series: return _checked_cast(s, to_dtype) + @staticmethod + def _to_repr(prefix: str = "") -> str: + return f"{prefix}Float64" + class Decimal(nw.Decimal, TypedColumn, _DType): def __init__( @@ -868,6 +947,10 @@ def _safe_cast(s: nw.Series, to_dtype: _DType | CfUnion) -> nw.Series: return _union_cast(Decimal, s, to_dtype) return _checked_cast(s, to_dtype) + @staticmethod + def _to_repr(prefix: str = "") -> str: + return f"{prefix}Decimal" + class Binary(nw.Binary, TypedColumn, _DType): def __init__( @@ -899,6 +982,10 @@ def _safe_cast(s: nw.Series, to_dtype: _DType | CfUnion) -> nw.Series: return _union_cast(Binary, s, to_dtype) return _checked_cast(s, to_dtype) + @staticmethod + def _to_repr(prefix: str = "") -> str: + return f"{prefix}Binary" + class Boolean(nw.Boolean, TypedColumn, _DType): def __init__( @@ -930,6 +1017,10 @@ def _safe_cast(s: nw.Series, to_dtype: _DType | CfUnion) -> nw.Series: return _union_cast(Boolean, s, to_dtype) return _checked_cast(s, to_dtype) + @staticmethod + def _to_repr(prefix: str = "") -> str: + return f"{prefix}Boolean" + class Categorical(nw.Categorical, TypedColumn, _DType): def __init__( @@ -961,6 +1052,10 @@ def _safe_cast(s: nw.Series, to_dtype: _DType | CfUnion) -> nw.Series: return _union_cast(Categorical, s, to_dtype) return _checked_cast(s, to_dtype) + @staticmethod + def _to_repr(prefix: str = "") -> str: + return f"{prefix}Categorical" + class Enum(nw.Enum, TypedColumn, _DType): def __init__( @@ -992,6 +1087,10 @@ def _safe_cast(s: nw.Series, to_dtype: _DType | CfUnion) -> nw.Series: return _union_cast(Enum, s, to_dtype) return _checked_cast(s, to_dtype) + @staticmethod + def _to_repr(prefix: str = "") -> str: + return f"{prefix}Enum" + class Date(nw.Date, TypedColumn, _DType): def __init__( @@ -1023,6 +1122,10 @@ def _safe_cast(s: nw.Series, to_dtype: _DType | CfUnion) -> nw.Series: return _union_cast(Date, s, to_dtype) return _checked_cast(s, to_dtype) + @staticmethod + def _to_repr(prefix: str = "") -> str: + return f"{prefix}Date" + class Datetime(nw.Datetime, TypedColumn, _DType): def __init__( @@ -1047,6 +1150,7 @@ def __init__( ) self.to_narwhals = self.__to_narwhals # type: ignore + self._to_repr = self.__to_repr # type: ignore @staticmethod def to_narwhals(): @@ -1072,6 +1176,13 @@ def _safe_cast(s: nw.Series, to_dtype: _DType | CfUnion) -> nw.Series: return _union_cast(Datetime, s, to_dtype) return _checked_cast(s, to_dtype) + @staticmethod + def _to_repr(prefix: str = "") -> str: + return f"{prefix}Datetime" + + def __to_repr(self, prefix: str = "") -> str: + return f"{prefix}Datetime(time_unit={_fmt_optional_string(self.time_unit)}, time_zone={_fmt_optional_string(self.time_zone)})" + class Duration(nw.Duration, TypedColumn, _DType): def __init__( @@ -1095,6 +1206,7 @@ def __init__( ) self.to_narwhals = self.__to_narwhals # type: ignore + self._to_repr = self.__to_repr # type: ignore @staticmethod def to_narwhals(): @@ -1116,6 +1228,13 @@ def _safe_cast(s: nw.Series, to_dtype: _DType | CfUnion) -> nw.Series: return _union_cast(Duration, s, to_dtype) return _checked_cast(s, to_dtype) + @staticmethod + def _to_repr(prefix: str = "") -> str: + return f"{prefix}Duration" + + def __to_repr(self, prefix: str = "") -> str: + return f"{prefix}Duration(time_unit={_fmt_optional_string(self.time_unit)})" + class String(nw.String, TypedColumn, _DType): def __init__( @@ -1147,6 +1266,10 @@ def _safe_cast(s: nw.Series, to_dtype: _DType | CfUnion) -> nw.Series: return _union_cast(String, s, to_dtype) return _checked_cast(s, to_dtype) + @staticmethod + def _to_repr(prefix: str = "") -> str: + return f"{prefix}String" + class Object(nw.Object, TypedColumn, _DType): def __init__( @@ -1178,6 +1301,10 @@ def _safe_cast(s: nw.Series, to_dtype: _DType | CfUnion) -> nw.Series: return _union_cast(Object, s, to_dtype) return _checked_cast(s, to_dtype) + @staticmethod + def _to_repr(prefix: str = "") -> str: + return f"{prefix}Object" + class Unknown(nw.Unknown, TypedColumn, _DType): def __init__( @@ -1209,6 +1336,10 @@ def _safe_cast(s: nw.Series, to_dtype: _DType | CfUnion) -> nw.Series: return _union_cast(Unknown, s, to_dtype) return _checked_cast(s, to_dtype) + @staticmethod + def _to_repr(prefix: str = "") -> str: + return f"{prefix}Unknown" + class Array(nw.Array, TypedColumn, _DType): def __init__( @@ -1249,6 +1380,9 @@ def _safe_cast(s: nw.Series, to_dtype: _DType | CfUnion) -> nw.Series: return _union_cast(Array, s, to_dtype) return _checked_cast(s, to_dtype) + def _to_repr(self, prefix="") -> str: # type: ignore + return f"{prefix}Array({self.inner._to_repr(prefix)})" + class List(nw.List, TypedColumn, _DType): def __init__( @@ -1286,6 +1420,9 @@ def _safe_cast(s: nw.Series, to_dtype: _DType | CfUnion) -> nw.Series: return _union_cast(List, s, to_dtype) return _checked_cast(s, to_dtype) + def _to_repr(self, prefix="") -> str: # type: ignore + return f"{prefix}List({self.inner._to_repr(prefix)})" + class _Field: name: str @@ -1340,6 +1477,13 @@ def _safe_cast(s: nw.Series, to_dtype: _DType | CfUnion) -> nw.Series: return _union_cast(Struct, s, to_dtype) return _checked_cast(s, to_dtype) + def _to_repr(self, prefix="") -> str: # type: ignore + dct = {} + for field in self.fields: + dct[field.name] = field.dtype._to_repr(prefix) + + return f"{prefix}Struct({_fmt_dict_string(dct)})" + class CfUnion: """Union type. diff --git a/src/checkedframe/_schema_generation.py b/src/checkedframe/_schema_generation.py index 46a61b9..e621efb 100644 --- a/src/checkedframe/_schema_generation.py +++ b/src/checkedframe/_schema_generation.py @@ -147,10 +147,10 @@ class MySchema(cf.Schema): kwargs_to_show.append(f"{k}={v}") display_kwargs = ", ".join(kwargs_to_show) - display_dtype = str(cf_dtype).replace("(", f"({import_alias}") + display_dtype = cf_dtype._to_repr(import_alias) columns.append( - f" {sanitized_col} = {import_alias}{display_dtype}({display_kwargs})".replace( + f" {sanitized_col} = {display_dtype}({display_kwargs})".replace( ")(", ", " if len(kwargs_to_show) > 0 else "" ) ) diff --git a/tests/test_schema_parsing.py b/tests/test_schema_parsing.py index 694cfeb..c64d5e1 100644 --- a/tests/test_schema_parsing.py +++ b/tests/test_schema_parsing.py @@ -1,4 +1,4 @@ -import tempfile +import datetime import polars as pl import pytest @@ -103,17 +103,27 @@ def test_schema_generation(): "class AASchema(cf.Schema):\n" ' column_0 = cf.Float64(name="reason code", nullable=True, allow_nan=True)\n' " y = cf.List(cf.Int64, nullable=True)\n" - " z = cf.List(cf.List(cf.Int64), nullable=True)" + " z = cf.List(cf.List(cf.Int64), nullable=True)\n" + ' datetime = cf.Datetime(time_unit="us", time_zone=None)\n' + ' struct = cf.Struct({"a": cf.Int64, "b": cf.Struct({"d": cf.List(cf.Int64)})}, nullable=True)' + ) + + df = pl.DataFrame( + { + "reason code": [1.0, float("nan"), None], + "y": [[1], [2], None], + "z": [[[1]], None, [[3]]], + "datetime": datetime.datetime(2016, 1, 22), + "struct": [ + {"a": 1, "b": {"d": [1, 2, 3]}}, + {"a": 2, "b": {"d": [1, 2, 3]}}, + None, + ], + } ) schema_repr = cf.generate_schema_repr( - pl.DataFrame( - { - "reason code": [1.0, float("nan"), None], - "y": [[1], [2], None], - "z": [[[1]], None, [[3]]], - } - ), + df, class_name="AASchema", ) @@ -127,7 +137,9 @@ def test_schema_generation_lazy(): "class AASchema(cf.Schema):\n" ' column_0 = cf.Float64(name="reason code")\n' " y = cf.List(cf.Int64)\n" - " z = cf.List(cf.List(cf.Int64))" + " z = cf.List(cf.List(cf.Int64))\n" + ' datetime = cf.Datetime(time_unit="us", time_zone=None)\n' + ' struct = cf.Struct({"a": cf.Int64, "b": cf.Struct({"d": cf.List(cf.Int64)})})' ) df = pl.DataFrame( @@ -135,6 +147,12 @@ def test_schema_generation_lazy(): "reason code": [1.0, float("nan"), None], "y": [[1], [2], None], "z": [[[1]], None, [[3]]], + "datetime": datetime.datetime(2016, 1, 22), + "struct": [ + {"a": 1, "b": {"d": [1, 2, 3]}}, + {"a": 2, "b": {"d": [1, 2, 3]}}, + None, + ], } )