diff --git a/python/tvm_ffi/dataclasses/_utils.py b/python/tvm_ffi/dataclasses/_utils.py index bd647a60..5ed4e963 100644 --- a/python/tvm_ffi/dataclasses/_utils.py +++ b/python/tvm_ffi/dataclasses/_utils.py @@ -58,14 +58,13 @@ def type_info_to_cls( def _add_method(name: str, func: Callable[..., Any]) -> None: if name == "__ffi_init__": name = "__c_ffi_init__" - if name in attrs: # already defined - return + # Allow overriding methods (including from base classes like Object.__repr__) + # by always adding to attrs, which will be used when creating the new class func.__module__ = cls.__module__ func.__name__ = name func.__qualname__ = f"{cls.__qualname__}.{name}" func.__doc__ = f"Method `{name}` of class `{cls.__qualname__}`" attrs[name] = func - setattr(cls, name, func) for name, method_impl in methods.items(): if method_impl is not None: @@ -98,6 +97,57 @@ def fill_dataclass_field(type_cls: type, type_field: TypeField) -> None: type_field.dataclass_field = rhs +def _get_all_fields(type_info: TypeInfo) -> list[TypeField]: + """Collect all fields from the type hierarchy, from parents to children.""" + fields: list[TypeField] = [] + cur_type_info: TypeInfo | None = type_info + while cur_type_info is not None: + fields.extend(reversed(cur_type_info.fields)) + cur_type_info = cur_type_info.parent_type_info + fields.reverse() + return fields + + +def method_repr(type_cls: type, type_info: TypeInfo) -> Callable[..., str]: + """Generate a ``__repr__`` method for the dataclass. + + The generated representation includes all fields with ``repr=True`` in + the format ``ClassName(field1=value1, field2=value2, ...)``. + """ + # Step 0. Collect all fields from the type hierarchy + fields = _get_all_fields(type_info) + + # Step 1. Filter fields that should appear in repr + repr_fields: list[str] = [] + for field in fields: + assert field.name is not None + assert field.dataclass_field is not None + if field.dataclass_field.repr: + repr_fields.append(field.name) + + # Step 2. Generate the repr method + if not repr_fields: + # No fields to show, return a simple class name representation + body_lines = [f"return f'{type_cls.__name__}()'"] + else: + # Build field representations + fields_str = ", ".join( + f"{field_name}={{self.{field_name}!r}}" for field_name in repr_fields + ) + body_lines = [f"return f'{type_cls.__name__}({fields_str})'"] + + source_lines = ["def __repr__(self) -> str:"] + source_lines.extend(f" {line}" for line in body_lines) + source = "\n".join(source_lines) + + # Note: Code generation in this case is guaranteed to be safe, + # because the generated code does not contain any untrusted input. + namespace: dict[str, Any] = {} + exec(source, {}, namespace) + __repr__ = namespace["__repr__"] + return __repr__ + + def method_init(type_cls: type, type_info: TypeInfo) -> Callable[..., None]: """Generate an ``__init__`` that forwards to the FFI constructor. @@ -105,12 +155,7 @@ def method_init(type_cls: type, type_info: TypeInfo) -> Callable[..., None]: reflected field list, supporting default values and ``__post_init__``. """ # Step 0. Collect all fields from the type hierarchy - fields: list[TypeField] = [] - cur_type_info: TypeInfo | None = type_info - while cur_type_info is not None: - fields.extend(reversed(cur_type_info.fields)) - cur_type_info = cur_type_info.parent_type_info - fields.reverse() + fields = _get_all_fields(type_info) # sanity check for type_method in type_info.methods: if type_method.name == "__ffi_init__": diff --git a/python/tvm_ffi/dataclasses/c_class.py b/python/tvm_ffi/dataclasses/c_class.py index 65d7c732..8171b1b5 100644 --- a/python/tvm_ffi/dataclasses/c_class.py +++ b/python/tvm_ffi/dataclasses/c_class.py @@ -41,7 +41,7 @@ @dataclass_transform(field_specifiers=(field,)) def c_class( - type_key: str, init: bool = True + type_key: str, init: bool = True, repr: bool = True ) -> Callable[[Type[_InputClsType]], Type[_InputClsType]]: # noqa: UP006 """(Experimental) Create a dataclass-like proxy for a C++ class registered with TVM FFI. @@ -71,6 +71,10 @@ def c_class( signature. The generated initializer calls the C++ ``__init__`` function registered with ``ObjectDef`` and invokes ``__post_init__`` if it exists on the Python class. + repr + If ``True`` and the Python class does not define ``__repr__``, a + representation method is auto-generated that includes all fields with + ``repr=True``. Returns ------- @@ -118,8 +122,9 @@ class MyClass: """ def decorator(super_type_cls: Type[_InputClsType]) -> Type[_InputClsType]: # noqa: UP006 - nonlocal init + nonlocal init, repr init = init and "__init__" not in super_type_cls.__dict__ + repr = repr and "__repr__" not in super_type_cls.__dict__ # Step 1. Retrieve `type_info` from registry type_info: TypeInfo = _lookup_or_register_type_info_from_type_key(type_key) assert type_info.parent_type_info is not None @@ -129,10 +134,11 @@ def decorator(super_type_cls: Type[_InputClsType]) -> Type[_InputClsType]: # no _utils.fill_dataclass_field(super_type_cls, type_field) # Step 3. Create the proxy class with the fields as properties fn_init = _utils.method_init(super_type_cls, type_info) if init else None + fn_repr = _utils.method_repr(super_type_cls, type_info) if repr else None type_cls: Type[_InputClsType] = _utils.type_info_to_cls( # noqa: UP006 type_info=type_info, cls=super_type_cls, - methods={"__init__": fn_init}, + methods={"__init__": fn_init, "__repr__": fn_repr}, ) _set_type_cls(type_info, type_cls) return type_cls diff --git a/python/tvm_ffi/dataclasses/field.py b/python/tvm_ffi/dataclasses/field.py index d10612cf..d0e27b1f 100644 --- a/python/tvm_ffi/dataclasses/field.py +++ b/python/tvm_ffi/dataclasses/field.py @@ -37,7 +37,7 @@ class Field: way the decorator understands. """ - __slots__ = ("default_factory", "init", "name") + __slots__ = ("default_factory", "init", "name", "repr") def __init__( self, @@ -45,11 +45,13 @@ def __init__( name: str | None = None, default_factory: Callable[[], _FieldValue] | _MISSING_TYPE = MISSING, init: bool = True, + repr: bool = True, ) -> None: """Do not call directly; use :func:`field` instead.""" self.name = name self.default_factory = default_factory self.init = init + self.repr = repr def field( @@ -57,6 +59,7 @@ def field( default: _FieldValue | _MISSING_TYPE = MISSING, # type: ignore[assignment] default_factory: Callable[[], _FieldValue] | _MISSING_TYPE = MISSING, # type: ignore[assignment] init: bool = True, + repr: bool = True, ) -> _FieldValue: """(Experimental) Declare a dataclass-style field on a :func:`c_class` proxy. @@ -78,6 +81,9 @@ def field( init If ``True`` the field is included in the generated ``__init__``. If ``False`` the field is omitted from input arguments of ``__init__``. + repr + If ``True`` the field is included in the generated ``__repr__``. + If ``False`` the field is omitted from the ``__repr__`` output. Note ---- @@ -123,9 +129,11 @@ class PyBase: raise ValueError("Cannot specify both `default` and `default_factory`") if not isinstance(init, bool): raise TypeError("`init` must be a bool") + if not isinstance(repr, bool): + raise TypeError("`repr` must be a bool") if default is not MISSING: default_factory = _make_default_factory(default) - ret = Field(default_factory=default_factory, init=init) + ret = Field(default_factory=default_factory, init=init, repr=repr) return cast(_FieldValue, ret) diff --git a/tests/python/test_dataclasses_c_class.py b/tests/python/test_dataclasses_c_class.py index 676bbf59..5361cb64 100644 --- a/tests/python/test_dataclasses_c_class.py +++ b/tests/python/test_dataclasses_c_class.py @@ -94,3 +94,38 @@ def test_cxx_class_init_subset_positional() -> None: assert obj.optional_field == -1 obj.optional_field = 11 assert obj.optional_field == 11 + + +def test_cxx_class_repr() -> None: + obj = _TestCxxClassDerived(v_i64=123, v_i32=456, v_f64=4.0, v_f32=8.0) + repr_str = repr(obj) + assert "_TestCxxClassDerived" in repr_str + if "__repr__" in _TestCxxClassDerived.__dict__: + assert "v_i64=123" in repr_str + assert "v_i32=456" in repr_str + assert "v_f64=4.0" in repr_str + assert "v_f32=8.0" in repr_str + + +def test_cxx_class_repr_default() -> None: + obj = _TestCxxClassDerived(v_i64=123, v_i32=456, v_f64=4.0) + repr_str = repr(obj) + assert "_TestCxxClassDerived" in repr_str + if "__repr__" in _TestCxxClassDerived.__dict__: + assert "v_i64=123" in repr_str + assert "v_i32=456" in repr_str + assert "v_f64=4.0" in repr_str + assert "v_f32=8.0" in repr_str + + +def test_cxx_class_repr_derived_derived() -> None: + obj = _TestCxxClassDerivedDerived( + v_i64=123, v_i32=456, v_f64=4.0, v_f32=8.0, v_str="hello", v_bool=True + ) + repr_str = repr(obj) + assert "_TestCxxClassDerivedDerived" in repr_str + if "__repr__" in _TestCxxClassDerivedDerived.__dict__: + assert "v_i64=123" in repr_str + assert "v_i32=456" in repr_str + assert "v_str='hello'" in repr_str or 'v_str="hello"' in repr_str + assert "v_bool=True" in repr_str