Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 144 additions & 0 deletions src/checkedframe/_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ def to_narwhals() -> NarwhalsDType | type[NarwhalsDType]: ...
@abstractmethod
def _safe_cast(s: nw.Series, to_dtype: _DType) -> nw.Series: ...

# Abstract hook: each concrete dtype returns its own name as source text with
# ``prefix`` prepended (the schema generator passes its import alias here).
@staticmethod
@abstractmethod
def _to_repr(prefix: str = "") -> str: ...


class _BoundedDType(_DType):
_min: int | float
Expand Down Expand Up @@ -205,6 +209,33 @@ def _union_cast(cls: type[TypedColumn], s: nw.Series, union: CfUnion) -> nw.Seri
raise error


def _fmt_optional_string(s: str | None) -> str | None:
return f'"{s}"' if s is not None else None


def _fmt_dict_string(d: dict) -> str:
items = []
for k, v in d.items():
# Format the key with double quotes
key = f'"{k}"'

# Check if the value is a dictionary and recurse
if isinstance(v, dict):
value = _fmt_dict_string(v)
# Check if the value is a string and remove quotes
elif isinstance(v, str):
value = v
# Otherwise, use json.dumps for standard formatting (e.g., numbers, lists)
else:
raise ValueError(
"All keys must be strings and all values must either be strings or dicts"
)

items.append(f"{key}: {value}")

return "{" + ", ".join(items) + "}"


class Int8(nw.Int8, _BoundedDType, TypedColumn):
_min = -128
_max = 127
Expand Down Expand Up @@ -245,6 +276,10 @@ def _safe_cast(s: nw.Series, to_dtype: _DType | CfUnion) -> nw.Series:

return _checked_cast(s, to_dtype)

@staticmethod
def _to_repr(prefix: str = "") -> str:
return f"{prefix}Int8"


class Int16(nw.Int16, _BoundedDType, TypedColumn):
_min = -32_768
Expand Down Expand Up @@ -292,6 +327,10 @@ def _safe_cast(s: nw.Series, to_dtype: _DType | CfUnion) -> nw.Series:

return _checked_cast(s, to_dtype)

@staticmethod
def _to_repr(prefix: str = "") -> str:
return f"{prefix}Int16"


class Int32(nw.Int32, _BoundedDType, TypedColumn):
_min = -2_147_483_648
Expand Down Expand Up @@ -339,6 +378,10 @@ def _safe_cast(s: nw.Series, to_dtype: _DType | CfUnion) -> nw.Series:

return _checked_cast(s, to_dtype)

@staticmethod
def _to_repr(prefix: str = "") -> str:
return f"{prefix}Int32"


class Int64(nw.Int64, TypedColumn, _BoundedDType):
_min = -9_223_372_036_854_775_808
Expand Down Expand Up @@ -386,6 +429,10 @@ def _safe_cast(s: nw.Series, to_dtype: _DType | CfUnion) -> nw.Series:

return _checked_cast(s, to_dtype)

@staticmethod
def _to_repr(prefix: str = "") -> str:
return f"{prefix}Int64"


class Int128(nw.Int128, TypedColumn, _BoundedDType):
_min = -170141183460469231731687303715884105728
Expand Down Expand Up @@ -442,6 +489,10 @@ def _safe_cast(s: nw.Series, to_dtype: _DType | CfUnion) -> nw.Series:

return _checked_cast(s, to_dtype)

@staticmethod
def _to_repr(prefix: str = "") -> str:
return f"{prefix}Int128"


class UInt8(nw.UInt8, TypedColumn, _BoundedDType):
_min = 0
Expand Down Expand Up @@ -498,6 +549,10 @@ def _safe_cast(s: nw.Series, to_dtype: _DType | CfUnion) -> nw.Series:

return _checked_cast(s, to_dtype)

@staticmethod
def _to_repr(prefix: str = "") -> str:
return f"{prefix}UInt8"


class UInt16(nw.UInt16, TypedColumn, _BoundedDType):
_min = 0
Expand Down Expand Up @@ -552,6 +607,10 @@ def _safe_cast(s: nw.Series, to_dtype: _DType | CfUnion) -> nw.Series:

return _checked_cast(s, to_dtype)

@staticmethod
def _to_repr(prefix: str = "") -> str:
return f"{prefix}UInt16"


class UInt32(nw.UInt32, TypedColumn, _BoundedDType):
_min = 0
Expand Down Expand Up @@ -603,6 +662,10 @@ def _safe_cast(s: nw.Series, to_dtype: _DType | CfUnion) -> nw.Series:

return _checked_cast(s, to_dtype)

@staticmethod
def _to_repr(prefix: str = "") -> str:
return f"{prefix}UInt32"


class UInt64(nw.UInt64, TypedColumn, _BoundedDType):
_min = 0
Expand Down Expand Up @@ -661,6 +724,10 @@ def _safe_cast(s: nw.Series, to_dtype: _DType | CfUnion) -> nw.Series:

return _checked_cast(s, to_dtype)

@staticmethod
def _to_repr(prefix: str = "") -> str:
return f"{prefix}UInt64"


class UInt128(nw.UInt128, TypedColumn, _BoundedDType):
_min = 0
Expand Down Expand Up @@ -716,6 +783,10 @@ def _safe_cast(s: nw.Series, to_dtype: _DType | CfUnion) -> nw.Series:

return _checked_cast(s, to_dtype)

@staticmethod
def _to_repr(prefix: str = "") -> str:
return f"{prefix}UInt128"


class Float32(nw.Float32, TypedColumn, _BoundedDType):
# min and max represent min/max representible int that can be converted without loss
Expand Down Expand Up @@ -778,6 +849,10 @@ def _safe_cast(s: nw.Series, to_dtype: _DType | CfUnion) -> nw.Series:

return _checked_cast(s, to_dtype)

@staticmethod
def _to_repr(prefix: str = "") -> str:
return f"{prefix}Float32"


class Float64(nw.Float64, TypedColumn, _BoundedDType):
_min = -9_007_199_254_740_991
Expand Down Expand Up @@ -837,6 +912,10 @@ def _safe_cast(s: nw.Series, to_dtype: _DType | CfUnion) -> nw.Series:

return _checked_cast(s, to_dtype)

@staticmethod
def _to_repr(prefix: str = "") -> str:
return f"{prefix}Float64"


class Decimal(nw.Decimal, TypedColumn, _DType):
def __init__(
Expand Down Expand Up @@ -868,6 +947,10 @@ def _safe_cast(s: nw.Series, to_dtype: _DType | CfUnion) -> nw.Series:
return _union_cast(Decimal, s, to_dtype)
return _checked_cast(s, to_dtype)

@staticmethod
def _to_repr(prefix: str = "") -> str:
return f"{prefix}Decimal"


class Binary(nw.Binary, TypedColumn, _DType):
def __init__(
Expand Down Expand Up @@ -899,6 +982,10 @@ def _safe_cast(s: nw.Series, to_dtype: _DType | CfUnion) -> nw.Series:
return _union_cast(Binary, s, to_dtype)
return _checked_cast(s, to_dtype)

@staticmethod
def _to_repr(prefix: str = "") -> str:
return f"{prefix}Binary"


class Boolean(nw.Boolean, TypedColumn, _DType):
def __init__(
Expand Down Expand Up @@ -930,6 +1017,10 @@ def _safe_cast(s: nw.Series, to_dtype: _DType | CfUnion) -> nw.Series:
return _union_cast(Boolean, s, to_dtype)
return _checked_cast(s, to_dtype)

@staticmethod
def _to_repr(prefix: str = "") -> str:
return f"{prefix}Boolean"


class Categorical(nw.Categorical, TypedColumn, _DType):
def __init__(
Expand Down Expand Up @@ -961,6 +1052,10 @@ def _safe_cast(s: nw.Series, to_dtype: _DType | CfUnion) -> nw.Series:
return _union_cast(Categorical, s, to_dtype)
return _checked_cast(s, to_dtype)

@staticmethod
def _to_repr(prefix: str = "") -> str:
return f"{prefix}Categorical"


class Enum(nw.Enum, TypedColumn, _DType):
def __init__(
Expand Down Expand Up @@ -992,6 +1087,10 @@ def _safe_cast(s: nw.Series, to_dtype: _DType | CfUnion) -> nw.Series:
return _union_cast(Enum, s, to_dtype)
return _checked_cast(s, to_dtype)

@staticmethod
def _to_repr(prefix: str = "") -> str:
return f"{prefix}Enum"


class Date(nw.Date, TypedColumn, _DType):
def __init__(
Expand Down Expand Up @@ -1023,6 +1122,10 @@ def _safe_cast(s: nw.Series, to_dtype: _DType | CfUnion) -> nw.Series:
return _union_cast(Date, s, to_dtype)
return _checked_cast(s, to_dtype)

@staticmethod
def _to_repr(prefix: str = "") -> str:
return f"{prefix}Date"


class Datetime(nw.Datetime, TypedColumn, _DType):
def __init__(
Expand All @@ -1047,6 +1150,7 @@ def __init__(
)

self.to_narwhals = self.__to_narwhals # type: ignore
self._to_repr = self.__to_repr # type: ignore

@staticmethod
def to_narwhals():
Expand All @@ -1072,6 +1176,13 @@ def _safe_cast(s: nw.Series, to_dtype: _DType | CfUnion) -> nw.Series:
return _union_cast(Datetime, s, to_dtype)
return _checked_cast(s, to_dtype)

@staticmethod
def _to_repr(prefix: str = "") -> str:
return f"{prefix}Datetime"

def __to_repr(self, prefix: str = "") -> str:
    """Return the instance form, including ``time_unit`` and ``time_zone``.

    Each argument is rendered as a double-quoted string, or ``None`` when
    the attribute is ``None``.
    """
    unit = _fmt_optional_string(self.time_unit)
    zone = _fmt_optional_string(self.time_zone)
    return f"{prefix}Datetime(time_unit={unit}, time_zone={zone})"


class Duration(nw.Duration, TypedColumn, _DType):
def __init__(
Expand All @@ -1095,6 +1206,7 @@ def __init__(
)

self.to_narwhals = self.__to_narwhals # type: ignore
self._to_repr = self.__to_repr # type: ignore

@staticmethod
def to_narwhals():
Expand All @@ -1116,6 +1228,13 @@ def _safe_cast(s: nw.Series, to_dtype: _DType | CfUnion) -> nw.Series:
return _union_cast(Duration, s, to_dtype)
return _checked_cast(s, to_dtype)

@staticmethod
def _to_repr(prefix: str = "") -> str:
return f"{prefix}Duration"

def __to_repr(self, prefix: str = "") -> str:
    """Return the instance form, including ``time_unit``.

    ``time_unit`` is rendered as a double-quoted string, or ``None`` when
    the attribute is ``None``.
    """
    unit = _fmt_optional_string(self.time_unit)
    return f"{prefix}Duration(time_unit={unit})"


class String(nw.String, TypedColumn, _DType):
def __init__(
Expand Down Expand Up @@ -1147,6 +1266,10 @@ def _safe_cast(s: nw.Series, to_dtype: _DType | CfUnion) -> nw.Series:
return _union_cast(String, s, to_dtype)
return _checked_cast(s, to_dtype)

@staticmethod
def _to_repr(prefix: str = "") -> str:
return f"{prefix}String"


class Object(nw.Object, TypedColumn, _DType):
def __init__(
Expand Down Expand Up @@ -1178,6 +1301,10 @@ def _safe_cast(s: nw.Series, to_dtype: _DType | CfUnion) -> nw.Series:
return _union_cast(Object, s, to_dtype)
return _checked_cast(s, to_dtype)

@staticmethod
def _to_repr(prefix: str = "") -> str:
return f"{prefix}Object"


class Unknown(nw.Unknown, TypedColumn, _DType):
def __init__(
Expand Down Expand Up @@ -1209,6 +1336,10 @@ def _safe_cast(s: nw.Series, to_dtype: _DType | CfUnion) -> nw.Series:
return _union_cast(Unknown, s, to_dtype)
return _checked_cast(s, to_dtype)

@staticmethod
def _to_repr(prefix: str = "") -> str:
return f"{prefix}Unknown"


class Array(nw.Array, TypedColumn, _DType):
def __init__(
Expand Down Expand Up @@ -1249,6 +1380,9 @@ def _safe_cast(s: nw.Series, to_dtype: _DType | CfUnion) -> nw.Series:
return _union_cast(Array, s, to_dtype)
return _checked_cast(s, to_dtype)

def _to_repr(self, prefix="") -> str: # type: ignore
return f"{prefix}Array({self.inner._to_repr(prefix)})"


class List(nw.List, TypedColumn, _DType):
def __init__(
Expand Down Expand Up @@ -1286,6 +1420,9 @@ def _safe_cast(s: nw.Series, to_dtype: _DType | CfUnion) -> nw.Series:
return _union_cast(List, s, to_dtype)
return _checked_cast(s, to_dtype)

def _to_repr(self, prefix="") -> str: # type: ignore
return f"{prefix}List({self.inner._to_repr(prefix)})"


class _Field:
name: str
Expand Down Expand Up @@ -1340,6 +1477,13 @@ def _safe_cast(s: nw.Series, to_dtype: _DType | CfUnion) -> nw.Series:
return _union_cast(Struct, s, to_dtype)
return _checked_cast(s, to_dtype)

def _to_repr(self, prefix="") -> str:  # type: ignore
    """Return ``Struct({...})`` mapping each field name to its dtype repr.

    Every field dtype is rendered through its own ``_to_repr`` with the same
    ``prefix``. The resulting mapping is formatted by ``_fmt_dict_string``.
    """
    field_reprs = {
        field.name: field.dtype._to_repr(prefix) for field in self.fields
    }
    body = _fmt_dict_string(field_reprs)
    return f"{prefix}Struct({body})"


class CfUnion:
"""Union type.
Expand Down
4 changes: 2 additions & 2 deletions src/checkedframe/_schema_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,10 +147,10 @@ class MySchema(cf.Schema):
kwargs_to_show.append(f"{k}={v}")

display_kwargs = ", ".join(kwargs_to_show)
display_dtype = str(cf_dtype).replace("(", f"({import_alias}")
display_dtype = cf_dtype._to_repr(import_alias)

columns.append(
f" {sanitized_col} = {import_alias}{display_dtype}({display_kwargs})".replace(
f" {sanitized_col} = {display_dtype}({display_kwargs})".replace(
")(", ", " if len(kwargs_to_show) > 0 else ""
)
)
Expand Down
Loading
Loading