Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 52 additions & 9 deletions python/tvm_ffi/dataclasses/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,13 @@ def type_info_to_cls(
def _add_method(name: str, func: Callable[..., Any]) -> None:
if name == "__ffi_init__":
name = "__c_ffi_init__"
if name in attrs: # already defined
return
# Allow overriding methods (including from base classes like Object.__repr__)
# by always adding to attrs, which will be used when creating the new class
func.__module__ = cls.__module__
func.__name__ = name
func.__qualname__ = f"{cls.__qualname__}.{name}"
func.__doc__ = f"Method `{name}` of class `{cls.__qualname__}`"
attrs[name] = func
setattr(cls, name, func)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

curious why we removed this line?


for name, method_impl in methods.items():
if method_impl is not None:
Expand Down Expand Up @@ -98,19 +97,63 @@ def fill_dataclass_field(type_cls: type, type_field: TypeField) -> None:
type_field.dataclass_field = rhs


def _get_all_fields(type_info: TypeInfo) -> list[TypeField]:
"""Collect all fields from the type hierarchy, from parents to children."""
fields: list[TypeField] = []
cur_type_info: TypeInfo | None = type_info
while cur_type_info is not None:
fields.extend(reversed(cur_type_info.fields))
cur_type_info = cur_type_info.parent_type_info
fields.reverse()
return fields


def method_repr(type_cls: type, type_info: TypeInfo) -> Callable[..., str]:
"""Generate a ``__repr__`` method for the dataclass.

The generated representation includes all fields with ``repr=True`` in
the format ``ClassName(field1=value1, field2=value2, ...)``.
"""
# Step 0. Collect all fields from the type hierarchy
fields = _get_all_fields(type_info)

# Step 1. Filter fields that should appear in repr
repr_fields: list[str] = []
for field in fields:
assert field.name is not None
assert field.dataclass_field is not None
if field.dataclass_field.repr:
repr_fields.append(field.name)

# Step 2. Generate the repr method
if not repr_fields:
# No fields to show, return a simple class name representation
body_lines = [f"return f'{type_cls.__name__}()'"]
else:
# Build field representations
fields_str = ", ".join(f"{field_name}={{self.{field_name}!r}}" for field_name in repr_fields)
body_lines = [f"return f'{type_cls.__name__}({fields_str})'"]

source_lines = ["def __repr__(self) -> str:"]
source_lines.extend(f" {line}" for line in body_lines)
source = "\n".join(source_lines)

# Note: Code generation in this case is guaranteed to be safe,
# because the generated code does not contain any untrusted input.
namespace: dict[str, Any] = {}
exec(source, {}, namespace)
__repr__ = namespace["__repr__"]
return __repr__


def method_init(type_cls: type, type_info: TypeInfo) -> Callable[..., None]:
"""Generate an ``__init__`` that forwards to the FFI constructor.

The generated initializer has a proper Python signature built from the
reflected field list, supporting default values and ``__post_init__``.
"""
# Step 0. Collect all fields from the type hierarchy
fields: list[TypeField] = []
cur_type_info: TypeInfo | None = type_info
while cur_type_info is not None:
fields.extend(reversed(cur_type_info.fields))
cur_type_info = cur_type_info.parent_type_info
fields.reverse()
fields = _get_all_fields(type_info)
# sanity check
for type_method in type_info.methods:
if type_method.name == "__ffi_init__":
Expand Down
12 changes: 9 additions & 3 deletions python/tvm_ffi/dataclasses/c_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@

@dataclass_transform(field_specifiers=(field,))
def c_class(
type_key: str, init: bool = True
type_key: str, init: bool = True, repr: bool = True
) -> Callable[[Type[_InputClsType]], Type[_InputClsType]]: # noqa: UP006
"""(Experimental) Create a dataclass-like proxy for a C++ class registered with TVM FFI.

Expand Down Expand Up @@ -71,6 +71,10 @@ def c_class(
signature. The generated initializer calls the C++ ``__init__``
function registered with ``ObjectDef`` and invokes ``__post_init__`` if
it exists on the Python class.
repr
If ``True`` and the Python class does not define ``__repr__``, a
representation method is auto-generated that includes all fields with
``repr=True``.

Returns
-------
Expand Down Expand Up @@ -118,8 +122,9 @@ class MyClass:
"""

def decorator(super_type_cls: Type[_InputClsType]) -> Type[_InputClsType]: # noqa: UP006
nonlocal init
nonlocal init, repr
init = init and "__init__" not in super_type_cls.__dict__
repr = repr and "__repr__" not in super_type_cls.__dict__
# Step 1. Retrieve `type_info` from registry
type_info: TypeInfo = _lookup_or_register_type_info_from_type_key(type_key)
assert type_info.parent_type_info is not None
Expand All @@ -129,10 +134,11 @@ def decorator(super_type_cls: Type[_InputClsType]) -> Type[_InputClsType]: # no
_utils.fill_dataclass_field(super_type_cls, type_field)
# Step 3. Create the proxy class with the fields as properties
fn_init = _utils.method_init(super_type_cls, type_info) if init else None
fn_repr = _utils.method_repr(super_type_cls, type_info) if repr else None
type_cls: Type[_InputClsType] = _utils.type_info_to_cls( # noqa: UP006
type_info=type_info,
cls=super_type_cls,
methods={"__init__": fn_init},
methods={"__init__": fn_init, "__repr__": fn_repr},
)
_set_type_cls(type_info, type_cls)
return type_cls
Expand Down
12 changes: 10 additions & 2 deletions python/tvm_ffi/dataclasses/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,26 +37,29 @@ class Field:
way the decorator understands.
"""

__slots__ = ("default_factory", "init", "name")
__slots__ = ("default_factory", "init", "name", "repr")

def __init__(
self,
*,
name: str | None = None,
default_factory: Callable[[], _FieldValue] | _MISSING_TYPE = MISSING,
init: bool = True,
repr: bool = True,
) -> None:
"""Do not call directly; use :func:`field` instead."""
self.name = name
self.default_factory = default_factory
self.init = init
self.repr = repr


def field(
*,
default: _FieldValue | _MISSING_TYPE = MISSING, # type: ignore[assignment]
default_factory: Callable[[], _FieldValue] | _MISSING_TYPE = MISSING, # type: ignore[assignment]
init: bool = True,
repr: bool = True,
) -> _FieldValue:
"""(Experimental) Declare a dataclass-style field on a :func:`c_class` proxy.

Expand All @@ -78,6 +81,9 @@ def field(
init
If ``True`` the field is included in the generated ``__init__``.
If ``False`` the field is omitted from input arguments of ``__init__``.
repr
If ``True`` the field is included in the generated ``__repr__``.
If ``False`` the field is omitted from the ``__repr__`` output.

Note
----
Expand Down Expand Up @@ -123,9 +129,11 @@ class PyBase:
raise ValueError("Cannot specify both `default` and `default_factory`")
if not isinstance(init, bool):
raise TypeError("`init` must be a bool")
if not isinstance(repr, bool):
raise TypeError("`repr` must be a bool")
if default is not MISSING:
default_factory = _make_default_factory(default)
ret = Field(default_factory=default_factory, init=init)
ret = Field(default_factory=default_factory, init=init, repr=repr)
return cast(_FieldValue, ret)


Expand Down
35 changes: 35 additions & 0 deletions tests/python/test_dataclasses_c_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,38 @@ def test_cxx_class_init_subset_positional() -> None:
assert obj.optional_field == -1
obj.optional_field = 11
assert obj.optional_field == 11


def test_cxx_class_repr() -> None:
obj = _TestCxxClassDerived(v_i64=123, v_i32=456, v_f64=4.0, v_f32=8.0)
repr_str = repr(obj)
assert "_TestCxxClassDerived" in repr_str
if "__repr__" in _TestCxxClassDerived.__dict__:
assert "v_i64=123" in repr_str
assert "v_i32=456" in repr_str
assert "v_f64=4.0" in repr_str
assert "v_f32=8.0" in repr_str


def test_cxx_class_repr_default() -> None:
obj = _TestCxxClassDerived(v_i64=123, v_i32=456, v_f64=4.0)
repr_str = repr(obj)
assert "_TestCxxClassDerived" in repr_str
if "__repr__" in _TestCxxClassDerived.__dict__:
assert "v_i64=123" in repr_str
assert "v_i32=456" in repr_str
assert "v_f64=4.0" in repr_str
assert "v_f32=8.0" in repr_str


def test_cxx_class_repr_derived_derived() -> None:
obj = _TestCxxClassDerivedDerived(
v_i64=123, v_i32=456, v_f64=4.0, v_f32=8.0, v_str="hello", v_bool=True
)
repr_str = repr(obj)
assert "_TestCxxClassDerivedDerived" in repr_str
if "__repr__" in _TestCxxClassDerivedDerived.__dict__:
assert "v_i64=123" in repr_str
assert "v_i32=456" in repr_str
assert "v_str='hello'" in repr_str or "v_str=\"hello\"" in repr_str
assert "v_bool=True" in repr_str
Loading