diff --git a/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/PythonClassTranslator.java b/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/PythonClassTranslator.java index 48d071a..6376728 100644 --- a/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/PythonClassTranslator.java +++ b/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/PythonClassTranslator.java @@ -159,11 +159,19 @@ public static PythonLikeType translatePythonClass(PythonCompiledClass pythonComp List nonObjectInterfaceImplementors = javaInterfaceImplementorSet.stream() .filter(implementor -> !Object.class.equals(implementor.getInterfaceClass())) - .collect(Collectors.toList()); - String[] interfaces = new String[nonObjectInterfaceImplementors.size()]; + .toList(); + + String[] interfaces = new String[nonObjectInterfaceImplementors.size() + pythonCompiledClass.markerInterfaces.size()]; for (int i = 0; i < nonObjectInterfaceImplementors.size(); i++) { interfaces[i] = Type.getInternalName(nonObjectInterfaceImplementors.get(i).getInterfaceClass()); } + for (int i = 0; i < pythonCompiledClass.markerInterfaces.size(); i++) { + var markerInterface = pythonCompiledClass.markerInterfaces.get(i); + if (!markerInterface.isInterface()) { + throw new IllegalArgumentException("%s is not an interface".formatted(markerInterface)); + } + interfaces[i + nonObjectInterfaceImplementors.size()] = Type.getInternalName(markerInterface); + } ClassWriter classWriter = new ClassWriter(ClassWriter.COMPUTE_MAXS | ClassWriter.COMPUTE_FRAMES); diff --git a/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/PythonCompiledClass.java b/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/PythonCompiledClass.java index 3b0a6aa..0fbc9b5 100644 --- a/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/PythonCompiledClass.java +++ b/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/PythonCompiledClass.java @@ -36,6 +36,11 @@ public class PythonCompiledClass { */ public Map typeAnnotations; + /** + * Marker interfaces the class implement + */ + public List> markerInterfaces; + /** * The binary type of this PythonCompiledClass; * typically {@link CPythonType}. Used when methods diff --git a/jpyinterpreter/src/main/python/__init__.py b/jpyinterpreter/src/main/python/__init__.py index d9a825a..7a7f9e5 100644 --- a/jpyinterpreter/src/main/python/__init__.py +++ b/jpyinterpreter/src/main/python/__init__.py @@ -2,7 +2,7 @@ This module acts as an interface to the Python bytecode to Java bytecode interpreter """ from .jvm_setup import init, set_class_output_directory -from .annotations import JavaAnnotation, add_class_annotation +from .annotations import JavaAnnotation, add_class_annotation, add_marker_interface from .conversions import (convert_to_java_python_like_object, unwrap_python_like_object, update_python_object_from_java, is_c_native) from .translator import (translate_python_bytecode_to_java_bytecode, diff --git a/jpyinterpreter/src/main/python/annotations.py b/jpyinterpreter/src/main/python/annotations.py index b3d9a35..61406a4 100644 --- a/jpyinterpreter/src/main/python/annotations.py +++ b/jpyinterpreter/src/main/python/annotations.py @@ -30,6 +30,19 @@ def decorator(_cls: Type[T]) -> Type[T]: return decorator +def add_marker_interface(marker_interface: JClass | str, /) -> Callable[[Type[T]], Type[T]]: + def decorator(_cls: Type[T]) -> Type[T]: + from .translator import type_to_compiled_java_class, type_to_marker_interfaces + if _cls in type_to_compiled_java_class: + raise RuntimeError('Cannot add a marker interface after a class been compiled.') + marker_interfaces = type_to_marker_interfaces.get(_cls, []) + marker_interfaces.append(marker_interface) + type_to_marker_interfaces[_cls] = marker_interfaces + return _cls + + return decorator + + def copy_type_annotations(hinted_object, default_args, vargs_name, kwargs_name): from java.util import HashMap, Collections from ai.timefold.jpyinterpreter import TypeHint diff --git a/jpyinterpreter/src/main/python/translator.py b/jpyinterpreter/src/main/python/translator.py index 90c399e..152f989 100644 --- a/jpyinterpreter/src/main/python/translator.py +++ b/jpyinterpreter/src/main/python/translator.py @@ -13,6 +13,7 @@ global_dict_to_key_set = dict() type_to_compiled_java_class = dict() type_to_annotations = dict() +type_to_marker_interfaces = dict() function_interface_pair_to_instance = dict() function_interface_pair_to_class = dict() @@ -629,9 +630,17 @@ def translate_python_class_to_java_class(python_class): python_compiled_class = PythonCompiledClass() python_compiled_class.annotations = ArrayList() + python_compiled_class.markerInterfaces = ArrayList() + for annotation in type_to_annotations.get(python_class, []): python_compiled_class.annotations.add(convert_java_annotation(annotation)) + for marker_interface in type_to_marker_interfaces.get(python_class, []): + if isinstance(marker_interface, str): + marker_interface = JClass(marker_interface) + + python_compiled_class.markerInterfaces.add(marker_interface) + python_compiled_class.binaryType = CPythonType.getType(JProxy(OpaquePythonReference, inst=python_class, convert=True)) python_compiled_class.module = python_class.__module__ diff --git a/jpyinterpreter/src/test/java/ai/timefold/jpyinterpreter/PythonClassTranslatorTest.java b/jpyinterpreter/src/test/java/ai/timefold/jpyinterpreter/PythonClassTranslatorTest.java index f6b94d9..d8d1099 100644 --- a/jpyinterpreter/src/test/java/ai/timefold/jpyinterpreter/PythonClassTranslatorTest.java +++ b/jpyinterpreter/src/test/java/ai/timefold/jpyinterpreter/PythonClassTranslatorTest.java @@ -2,6 +2,7 @@ import static org.assertj.core.api.AssertionsForClassTypes.assertThat; +import java.util.Collections; import java.util.List; import java.util.Map; import java.util.function.Function; @@ -41,7 +42,8 @@ public void testPythonClassTranslation() throws ClassNotFoundException, NoSuchMe .op(ControlOpDescriptor.RETURN_VALUE) .build(); - compiledClass.annotations = List.of(); + compiledClass.annotations = Collections.emptyList(); + compiledClass.markerInterfaces = Collections.emptyList(); compiledClass.className = "MyClass"; compiledClass.superclassList = List.of(BuiltinTypes.BASE_TYPE); compiledClass.staticAttributeNameToObject = Map.of("type_variable", new PythonString("type_value")); @@ -94,7 +96,8 @@ public void testPythonClassComparable() throws ClassNotFoundException { PythonCompiledFunction comparisonFunction = getCompareFunction.apply(compareOp); PythonCompiledClass compiledClass = new PythonCompiledClass(); - compiledClass.annotations = List.of(); + compiledClass.annotations = Collections.emptyList(); + compiledClass.markerInterfaces = Collections.emptyList(); compiledClass.className = "MyClass"; compiledClass.superclassList = List.of(BuiltinTypes.BASE_TYPE); compiledClass.staticAttributeNameToObject = Map.of(); @@ -163,7 +166,8 @@ public void testPythonClassEqualsAndHashCode() throws ClassNotFoundException { .build(); PythonCompiledClass compiledClass = new PythonCompiledClass(); - compiledClass.annotations = List.of(); + compiledClass.annotations = Collections.emptyList(); + compiledClass.markerInterfaces = Collections.emptyList(); compiledClass.className = "MyClass"; compiledClass.superclassList = List.of(BuiltinTypes.BASE_TYPE); compiledClass.staticAttributeNameToObject = Map.of(); diff --git a/jpyinterpreter/tests/test_classes.py b/jpyinterpreter/tests/test_classes.py index 6c5e755..cde5fed 100644 --- a/jpyinterpreter/tests/test_classes.py +++ b/jpyinterpreter/tests/test_classes.py @@ -962,3 +962,27 @@ class A: translated_class = translate_python_class_to_java_class(A).getJavaClass() field_type = translated_class.getField('my_field').getGenericType() assert field_type.getActualTypeArguments()[0].getName() == PythonString.class_.getName() + + +def test_marker_interface(): + from ai.timefold.jpyinterpreter.types.wrappers import OpaquePythonReference + from jpyinterpreter import translate_python_class_to_java_class, add_marker_interface + + @add_marker_interface(OpaquePythonReference) + class A: + pass + + translated_class = translate_python_class_to_java_class(A).getJavaClass() + assert OpaquePythonReference.class_.isAssignableFrom(translated_class) + + +def test_marker_interface_string(): + from ai.timefold.jpyinterpreter.types.wrappers import OpaquePythonReference + from jpyinterpreter import translate_python_class_to_java_class, add_marker_interface + + @add_marker_interface('ai.timefold.jpyinterpreter.types.wrappers.OpaquePythonReference') + class A: + pass + + translated_class = translate_python_class_to_java_class(A).getJavaClass() + assert OpaquePythonReference.class_.isAssignableFrom(translated_class) diff --git a/tests/test_constraint_streams.py b/tests/test_constraint_streams.py index dae8d07..5aad0d6 100644 --- a/tests/test_constraint_streams.py +++ b/tests/test_constraint_streams.py @@ -517,6 +517,48 @@ def define_constraints(constraint_factory: ConstraintFactory): } +def test_custom_justifications(): + @dataclass(unsafe_hash=True) + class MyJustification(ConstraintJustification): + code: str + score: SimpleScore + + @constraint_provider + def define_constraints(constraint_factory: ConstraintFactory): + return [ + constraint_factory.for_each(Entity) + .reward(SimpleScore.ONE, lambda e: e.value.number) + .justify_with(lambda e, score: MyJustification(e.code, score)) + .as_constraint('my_package', 'Maximize value') + ] + + score_manager = create_score_manager(define_constraints) + entity_a: Entity = Entity('A') + entity_b: Entity = Entity('B') + + value_1 = Value(1) + value_2 = Value(2) + value_3 = Value(3) + + entity_a.value = value_1 + entity_b.value = value_3 + + problem = Solution([entity_a, entity_b], [value_1, value_2, value_3]) + + justifications = score_manager.explain(problem).get_justification_list() + assert len(justifications) == 2 + assert MyJustification('A', SimpleScore.of(1)) in justifications + assert MyJustification('B', SimpleScore.of(3)) in justifications + + justifications = score_manager.explain(problem).get_justification_list(MyJustification) + assert len(justifications) == 2 + assert MyJustification('A', SimpleScore.of(1)) in justifications + assert MyJustification('B', SimpleScore.of(3)) in justifications + + justifications = score_manager.explain(problem).get_justification_list(DefaultConstraintJustification) + assert len(justifications) == 0 + + ignored_python_functions = { '_call_comparison_java_joiner', '__init__', @@ -534,23 +576,19 @@ def define_constraints(constraint_factory: ConstraintFactory): 'countLongBi', # Python has no concept of Long (everything a BigInteger) 'countLongQuad', 'countLongTri', - '_handler', # JPype handler field should be ignored - # Unimplemented penalize/reward/impact 'impactBigDecimal', - 'impactConfigurable', 'impactConfigurableBigDecimal', 'impactConfigurableLong', 'impactLong', 'penalizeBigDecimal', - 'penalizeConfigurable', 'penalizeConfigurableBigDecimal', 'penalizeConfigurableLong', 'penalizeLong', 'rewardBigDecimal', - 'rewardConfigurable', 'rewardConfigurableBigDecimal', 'rewardConfigurableLong', 'rewardLong', + '_handler', # JPype handler field should be ignored # These methods are deprecated 'from_', 'fromUnfiltered', diff --git a/timefold-solver-python-core/src/main/python/api/_solution_manager.py b/timefold-solver-python-core/src/main/python/api/_solution_manager.py index 2e226aa..bb43e2b 100644 --- a/timefold-solver-python-core/src/main/python/api/_solution_manager.py +++ b/timefold-solver-python-core/src/main/python/api/_solution_manager.py @@ -1,10 +1,10 @@ from ._solver_factory import SolverFactory from ._solver_manager import SolverManager from .._timefold_java_interop import get_class -from jpyinterpreter import unwrap_python_like_object +from jpyinterpreter import unwrap_python_like_object, add_marker_interface from dataclasses import dataclass -from typing import TypeVar, Generic, Union, TYPE_CHECKING, Any, cast, Optional +from typing import TypeVar, Generic, Union, TYPE_CHECKING, Any, cast, Optional, Type if TYPE_CHECKING: # These imports require a JVM to be running, so only import if type checking @@ -24,6 +24,7 @@ Solution_ = TypeVar('Solution_') ProblemId_ = TypeVar('ProblemId_') Score_ = TypeVar('Score_', bound='Score') +Justification_ = TypeVar('Justification_', bound='ConstraintJustification') @dataclass(frozen=True, unsafe_hash=True) @@ -100,8 +101,13 @@ def __hash__(self) -> int: return combined_hash +@add_marker_interface('ai.timefold.solver.core.api.score.stream.ConstraintJustification') +class ConstraintJustification: + pass + + @dataclass(frozen=True, eq=True) -class DefaultConstraintJustification: +class DefaultConstraintJustification(ConstraintJustification): facts: tuple[Any, ...] impact: Score_ @@ -127,7 +133,7 @@ def _map_constraint_match_set(constraint_match_set: set['_JavaConstraintMatch']) } -def _unwrap_justification(justification: Any) -> Any: +def _unwrap_justification(justification: Any) -> ConstraintJustification: from ai.timefold.solver.core.api.score.stream import ( DefaultConstraintJustification as _JavaDefaultConstraintJustification) if isinstance(justification, _JavaDefaultConstraintJustification): @@ -139,7 +145,7 @@ def _unwrap_justification(justification: Any) -> Any: return unwrap_python_like_object(justification) -def _unwrap_justification_list(justification_list: list[Any]) -> list[Any]: +def _unwrap_justification_list(justification_list: list[Any]) -> list[ConstraintJustification]: return [_unwrap_justification(justification) for justification in justification_list] @@ -163,7 +169,7 @@ def constraint_match_set(self) -> set[ConstraintMatch[Score_]]: def indicted_object(self) -> Any: return unwrap_python_like_object(self._delegate.getIndictedObject()) - def get_justification_list(self, justification_type=None) -> list[Any]: + def get_justification_list(self, justification_type: Type[Justification_] = None) -> list[Justification_]: if justification_type is None: justification_list = self._delegate.getJustificationList() else: @@ -250,7 +256,7 @@ def solution(self) -> Solution_: def summary(self) -> str: return self._delegate.getSummary() - def get_justification_list(self, justification_type=None) -> list[Any]: + def get_justification_list(self, justification_type: Type[Justification_] = None) -> list[Justification_]: if justification_type is None: justification_list = self._delegate.getJustificationList() else: @@ -275,7 +281,7 @@ def score(self) -> Score_: return self._delegate.score() @property - def justification(self) -> Any: + def justification(self) -> ConstraintJustification: return _unwrap_justification(self._delegate.justification()) @@ -343,5 +349,5 @@ def constraint_analyses(self) -> list[ConstraintAnalysis]: __all__ = ['SolutionManager', 'ScoreExplanation', 'ConstraintRef', 'ConstraintMatch', 'ConstraintMatchTotal', - 'DefaultConstraintJustification', 'Indictment', + 'ConstraintJustification', 'DefaultConstraintJustification', 'Indictment', 'ScoreAnalysis', 'ConstraintAnalysis', 'MatchAnalysis'] diff --git a/timefold-solver-python-core/src/main/python/constraint/_constraint_builder.py b/timefold-solver-python-core/src/main/python/constraint/_constraint_builder.py index df63e0c..2549ae8 100644 --- a/timefold-solver-python-core/src/main/python/constraint/_constraint_builder.py +++ b/timefold-solver-python-core/src/main/python/constraint/_constraint_builder.py @@ -1,4 +1,5 @@ from ._function_translator import function_cast +import timefold.solver.api as api from typing import TypeVar, Callable, Generic, Collection, Any, TYPE_CHECKING, Type if TYPE_CHECKING: @@ -31,10 +32,11 @@ def indict_with(self, indictment_function: Callable[[A], Collection]) -> 'UniCon return UniConstraintBuilder(self.delegate.indictWith( function_cast(indictment_function, self.a_type)), self.a_type) - def justify_with(self, justification_function: Callable[[A, ScoreType], Any]) -> \ + def justify_with(self, justification_function: Callable[[A, ScoreType], 'api.ConstraintJustification']) -> \ 'UniConstraintBuilder[A, ScoreType]': + from ai.timefold.solver.core.api.score import Score return UniConstraintBuilder(self.delegate.justifyWith( - function_cast(justification_function, self.a_type)), self.a_type) + function_cast(justification_function, self.a_type, Score)), self.a_type) def as_constraint(self, constraint_package_or_name: str, constraint_name: str = None) -> '_JavaConstraint': if constraint_name is None: @@ -58,10 +60,11 @@ def indict_with(self, indictment_function: Callable[[A, B], Collection]) -> 'BiC return BiConstraintBuilder(self.delegate.indictWith( function_cast(indictment_function, self.a_type, self.b_type)), self.a_type, self.b_type) - def justify_with(self, justification_function: Callable[[A, B, ScoreType], Any]) -> \ + def justify_with(self, justification_function: Callable[[A, B, ScoreType], 'api.ConstraintJustification']) -> \ 'BiConstraintBuilder[A, B, ScoreType]': + from ai.timefold.solver.core.api.score import Score return BiConstraintBuilder(self.delegate.justifyWith( - function_cast(justification_function, self.a_type, self.b_type)), self.a_type, self.b_type) + function_cast(justification_function, self.a_type, self.b_type, Score)), self.a_type, self.b_type) def as_constraint(self, constraint_package_or_name: str, constraint_name: str = None) -> '_JavaConstraint': if constraint_name is None: @@ -89,11 +92,12 @@ def indict_with(self, indictment_function: Callable[[A, B, C], Collection]) -> \ function_cast(indictment_function, self.a_type, self.b_type, self.c_type)), self.a_type, self.b_type, self.c_type) - def justify_with(self, justification_function: Callable[[A, B, C, ScoreType], Any]) -> \ + def justify_with(self, justification_function: Callable[[A, B, C, ScoreType], 'api.ConstraintJustification']) -> \ 'TriConstraintBuilder[A, B, C, ScoreType]': + from ai.timefold.solver.core.api.score import Score return TriConstraintBuilder(self.delegate.justifyWith( - function_cast(justification_function, self.a_type, self.b_type, self.c_type)), self.a_type, self.b_type, - self.c_type) + function_cast(justification_function, self.a_type, self.b_type, self.c_type, Score)), + self.a_type, self.b_type, self.c_type) def as_constraint(self, constraint_package_or_name: str, constraint_name: str = None) -> '_JavaConstraint': if constraint_name is None: @@ -123,10 +127,11 @@ def indict_with(self, indictment_function: Callable[[A, B, C, D], Collection]) - function_cast(indictment_function, self.a_type, self.b_type, self.c_type, self.d_type)), self.a_type, self.b_type, self.c_type, self.d_type) - def justify_with(self, justification_function: Callable[[A, B, C, D, ScoreType], Any]) -> \ - 'QuadConstraintBuilder[A, B, C, D, ScoreType]': + def justify_with(self, justification_function: Callable[[A, B, C, D, ScoreType], 'api.ConstraintJustification']) \ + -> 'QuadConstraintBuilder[A, B, C, D, ScoreType]': + from ai.timefold.solver.core.api.score import Score return QuadConstraintBuilder(self.delegate.justifyWith( - function_cast(justification_function, self.a_type, self.b_type, self.c_type, self.d_type)), + function_cast(justification_function, self.a_type, self.b_type, self.c_type, self.d_type, Score)), self.a_type, self.b_type, self.c_type, self.d_type) def as_constraint(self, constraint_package_or_name: str, constraint_name: str = None) -> '_JavaConstraint':