Skip to content
This repository was archived by the owner on Jul 17, 2024. It is now read-only.

feat: Add support for custom justifications #39

Merged
merged 4 commits into from
Apr 24, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -159,11 +159,19 @@ public static PythonLikeType translatePythonClass(PythonCompiledClass pythonComp

List<JavaInterfaceImplementor> nonObjectInterfaceImplementors = javaInterfaceImplementorSet.stream()
.filter(implementor -> !Object.class.equals(implementor.getInterfaceClass()))
.collect(Collectors.toList());
String[] interfaces = new String[nonObjectInterfaceImplementors.size()];
.toList();

String[] interfaces = new String[nonObjectInterfaceImplementors.size() + pythonCompiledClass.markerInterfaces.size()];
for (int i = 0; i < nonObjectInterfaceImplementors.size(); i++) {
interfaces[i] = Type.getInternalName(nonObjectInterfaceImplementors.get(i).getInterfaceClass());
}
for (int i = 0; i < pythonCompiledClass.markerInterfaces.size(); i++) {
var markerInterface = pythonCompiledClass.markerInterfaces.get(i);
if (!markerInterface.isInterface()) {
throw new IllegalArgumentException("%s is not an interface".formatted(markerInterface));
}
interfaces[i + nonObjectInterfaceImplementors.size()] = Type.getInternalName(markerInterface);
}

ClassWriter classWriter = new ClassWriter(ClassWriter.COMPUTE_MAXS | ClassWriter.COMPUTE_FRAMES);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ public class PythonCompiledClass {
*/
public Map<String, TypeHint> typeAnnotations;

/**
* Marker interfaces the class implement
*/
public List<Class<?>> markerInterfaces;

/**
* The binary type of this PythonCompiledClass;
* typically {@link CPythonType}. Used when methods
Expand Down
2 changes: 1 addition & 1 deletion jpyinterpreter/src/main/python/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
This module acts as an interface to the Python bytecode to Java bytecode interpreter
"""
from .jvm_setup import init, set_class_output_directory
from .annotations import JavaAnnotation, add_class_annotation
from .annotations import JavaAnnotation, add_class_annotation, add_marker_interface
from .conversions import (convert_to_java_python_like_object, unwrap_python_like_object,
update_python_object_from_java, is_c_native)
from .translator import (translate_python_bytecode_to_java_bytecode,
Expand Down
13 changes: 13 additions & 0 deletions jpyinterpreter/src/main/python/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,19 @@ def decorator(_cls: Type[T]) -> Type[T]:
return decorator


def add_marker_interface(marker_interface: JClass | str, /) -> Callable[[Type[T]], Type[T]]:
def decorator(_cls: Type[T]) -> Type[T]:
from .translator import type_to_compiled_java_class, type_to_marker_interfaces
if _cls in type_to_compiled_java_class:
raise RuntimeError('Cannot add a marker interface after a class been compiled.')
marker_interfaces = type_to_marker_interfaces.get(_cls, [])
marker_interfaces.append(marker_interface)
type_to_marker_interfaces[_cls] = marker_interfaces
return _cls

return decorator


def copy_type_annotations(hinted_object, default_args, vargs_name, kwargs_name):
from java.util import HashMap, Collections
from ai.timefold.jpyinterpreter import TypeHint
Expand Down
9 changes: 9 additions & 0 deletions jpyinterpreter/src/main/python/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
global_dict_to_key_set = dict()
type_to_compiled_java_class = dict()
type_to_annotations = dict()
type_to_marker_interfaces = dict()

function_interface_pair_to_instance = dict()
function_interface_pair_to_class = dict()
Expand Down Expand Up @@ -629,9 +630,17 @@ def translate_python_class_to_java_class(python_class):

python_compiled_class = PythonCompiledClass()
python_compiled_class.annotations = ArrayList()
python_compiled_class.markerInterfaces = ArrayList()

for annotation in type_to_annotations.get(python_class, []):
python_compiled_class.annotations.add(convert_java_annotation(annotation))

for marker_interface in type_to_marker_interfaces.get(python_class, []):
if isinstance(marker_interface, str):
marker_interface = JClass(marker_interface)

python_compiled_class.markerInterfaces.add(marker_interface)

python_compiled_class.binaryType = CPythonType.getType(JProxy(OpaquePythonReference, inst=python_class,
convert=True))
python_compiled_class.module = python_class.__module__
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import static org.assertj.core.api.AssertionsForClassTypes.assertThat;

import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
Expand Down Expand Up @@ -41,7 +42,8 @@ public void testPythonClassTranslation() throws ClassNotFoundException, NoSuchMe
.op(ControlOpDescriptor.RETURN_VALUE)
.build();

compiledClass.annotations = List.of();
compiledClass.annotations = Collections.emptyList();
compiledClass.markerInterfaces = Collections.emptyList();
compiledClass.className = "MyClass";
compiledClass.superclassList = List.of(BuiltinTypes.BASE_TYPE);
compiledClass.staticAttributeNameToObject = Map.of("type_variable", new PythonString("type_value"));
Expand Down Expand Up @@ -94,7 +96,8 @@ public void testPythonClassComparable() throws ClassNotFoundException {
PythonCompiledFunction comparisonFunction = getCompareFunction.apply(compareOp);

PythonCompiledClass compiledClass = new PythonCompiledClass();
compiledClass.annotations = List.of();
compiledClass.annotations = Collections.emptyList();
compiledClass.markerInterfaces = Collections.emptyList();
compiledClass.className = "MyClass";
compiledClass.superclassList = List.of(BuiltinTypes.BASE_TYPE);
compiledClass.staticAttributeNameToObject = Map.of();
Expand Down Expand Up @@ -163,7 +166,8 @@ public void testPythonClassEqualsAndHashCode() throws ClassNotFoundException {
.build();

PythonCompiledClass compiledClass = new PythonCompiledClass();
compiledClass.annotations = List.of();
compiledClass.annotations = Collections.emptyList();
compiledClass.markerInterfaces = Collections.emptyList();
compiledClass.className = "MyClass";
compiledClass.superclassList = List.of(BuiltinTypes.BASE_TYPE);
compiledClass.staticAttributeNameToObject = Map.of();
Expand Down
24 changes: 24 additions & 0 deletions jpyinterpreter/tests/test_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -962,3 +962,27 @@ class A:
translated_class = translate_python_class_to_java_class(A).getJavaClass()
field_type = translated_class.getField('my_field').getGenericType()
assert field_type.getActualTypeArguments()[0].getName() == PythonString.class_.getName()


def test_marker_interface():
from ai.timefold.jpyinterpreter.types.wrappers import OpaquePythonReference
from jpyinterpreter import translate_python_class_to_java_class, add_marker_interface

@add_marker_interface(OpaquePythonReference)
class A:
pass

translated_class = translate_python_class_to_java_class(A).getJavaClass()
assert OpaquePythonReference.class_.isAssignableFrom(translated_class)


def test_marker_interface_string():
from ai.timefold.jpyinterpreter.types.wrappers import OpaquePythonReference
from jpyinterpreter import translate_python_class_to_java_class, add_marker_interface

@add_marker_interface('ai.timefold.jpyinterpreter.types.wrappers.OpaquePythonReference')
class A:
pass

translated_class = translate_python_class_to_java_class(A).getJavaClass()
assert OpaquePythonReference.class_.isAssignableFrom(translated_class)
48 changes: 43 additions & 5 deletions tests/test_constraint_streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,48 @@ def define_constraints(constraint_factory: ConstraintFactory):
}


def test_custom_justifications():
@dataclass(unsafe_hash=True)
class MyJustification(ConstraintJustification):
code: str
score: SimpleScore

@constraint_provider
def define_constraints(constraint_factory: ConstraintFactory):
return [
constraint_factory.for_each(Entity)
.reward(SimpleScore.ONE, lambda e: e.value.number)
.justify_with(lambda e, score: MyJustification(e.code, score))
.as_constraint('my_package', 'Maximize value')
]

score_manager = create_score_manager(define_constraints)
entity_a: Entity = Entity('A')
entity_b: Entity = Entity('B')

value_1 = Value(1)
value_2 = Value(2)
value_3 = Value(3)

entity_a.value = value_1
entity_b.value = value_3

problem = Solution([entity_a, entity_b], [value_1, value_2, value_3])

justifications = score_manager.explain(problem).get_justification_list()
assert len(justifications) == 2
assert MyJustification('A', SimpleScore.of(1)) in justifications
assert MyJustification('B', SimpleScore.of(3)) in justifications

justifications = score_manager.explain(problem).get_justification_list(MyJustification)
assert len(justifications) == 2
assert MyJustification('A', SimpleScore.of(1)) in justifications
assert MyJustification('B', SimpleScore.of(3)) in justifications

justifications = score_manager.explain(problem).get_justification_list(DefaultConstraintJustification)
assert len(justifications) == 0


ignored_python_functions = {
'_call_comparison_java_joiner',
'__init__',
Expand All @@ -534,23 +576,19 @@ def define_constraints(constraint_factory: ConstraintFactory):
'countLongBi', # Python has no concept of Long (everything a BigInteger)
'countLongQuad',
'countLongTri',
'_handler', # JPype handler field should be ignored
# Unimplemented penalize/reward/impact
'impactBigDecimal',
'impactConfigurable',
'impactConfigurableBigDecimal',
'impactConfigurableLong',
'impactLong',
'penalizeBigDecimal',
'penalizeConfigurable',
'penalizeConfigurableBigDecimal',
'penalizeConfigurableLong',
'penalizeLong',
'rewardBigDecimal',
'rewardConfigurable',
'rewardConfigurableBigDecimal',
'rewardConfigurableLong',
'rewardLong',
'_handler', # JPype handler field should be ignored
# These methods are deprecated
'from_',
'fromUnfiltered',
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from ._solver_factory import SolverFactory
from ._solver_manager import SolverManager
from .._timefold_java_interop import get_class
from jpyinterpreter import unwrap_python_like_object
from jpyinterpreter import unwrap_python_like_object, add_marker_interface
from dataclasses import dataclass

from typing import TypeVar, Generic, Union, TYPE_CHECKING, Any, cast, Optional
from typing import TypeVar, Generic, Union, TYPE_CHECKING, Any, cast, Optional, Type

if TYPE_CHECKING:
# These imports require a JVM to be running, so only import if type checking
Expand All @@ -24,6 +24,7 @@
Solution_ = TypeVar('Solution_')
ProblemId_ = TypeVar('ProblemId_')
Score_ = TypeVar('Score_', bound='Score')
Justification_ = TypeVar('Justification_', bound='ConstraintJustification')


@dataclass(frozen=True, unsafe_hash=True)
Expand Down Expand Up @@ -100,8 +101,13 @@ def __hash__(self) -> int:
return combined_hash


@add_marker_interface('ai.timefold.solver.core.api.score.stream.ConstraintJustification')
class ConstraintJustification:
pass


@dataclass(frozen=True, eq=True)
class DefaultConstraintJustification:
class DefaultConstraintJustification(ConstraintJustification):
facts: tuple[Any, ...]
impact: Score_

Expand All @@ -127,7 +133,7 @@ def _map_constraint_match_set(constraint_match_set: set['_JavaConstraintMatch'])
}


def _unwrap_justification(justification: Any) -> Any:
def _unwrap_justification(justification: Any) -> ConstraintJustification:
from ai.timefold.solver.core.api.score.stream import (
DefaultConstraintJustification as _JavaDefaultConstraintJustification)
if isinstance(justification, _JavaDefaultConstraintJustification):
Expand All @@ -139,7 +145,7 @@ def _unwrap_justification(justification: Any) -> Any:
return unwrap_python_like_object(justification)


def _unwrap_justification_list(justification_list: list[Any]) -> list[Any]:
def _unwrap_justification_list(justification_list: list[Any]) -> list[ConstraintJustification]:
return [_unwrap_justification(justification) for justification in justification_list]


Expand All @@ -163,7 +169,7 @@ def constraint_match_set(self) -> set[ConstraintMatch[Score_]]:
def indicted_object(self) -> Any:
return unwrap_python_like_object(self._delegate.getIndictedObject())

def get_justification_list(self, justification_type=None) -> list[Any]:
def get_justification_list(self, justification_type: Type[Justification_] = None) -> list[Justification_]:
if justification_type is None:
justification_list = self._delegate.getJustificationList()
else:
Expand Down Expand Up @@ -250,7 +256,7 @@ def solution(self) -> Solution_:
def summary(self) -> str:
return self._delegate.getSummary()

def get_justification_list(self, justification_type=None) -> list[Any]:
def get_justification_list(self, justification_type: Type[Justification_] = None) -> list[Justification_]:
if justification_type is None:
justification_list = self._delegate.getJustificationList()
else:
Expand All @@ -275,7 +281,7 @@ def score(self) -> Score_:
return self._delegate.score()

@property
def justification(self) -> Any:
def justification(self) -> ConstraintJustification:
return _unwrap_justification(self._delegate.justification())


Expand Down Expand Up @@ -343,5 +349,5 @@ def constraint_analyses(self) -> list[ConstraintAnalysis]:

__all__ = ['SolutionManager', 'ScoreExplanation',
'ConstraintRef', 'ConstraintMatch', 'ConstraintMatchTotal',
'DefaultConstraintJustification', 'Indictment',
'ConstraintJustification', 'DefaultConstraintJustification', 'Indictment',
'ScoreAnalysis', 'ConstraintAnalysis', 'MatchAnalysis']
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from ._function_translator import function_cast
import timefold.solver.api as api
from typing import TypeVar, Callable, Generic, Collection, Any, TYPE_CHECKING, Type

if TYPE_CHECKING:
Expand Down Expand Up @@ -31,10 +32,11 @@ def indict_with(self, indictment_function: Callable[[A], Collection]) -> 'UniCon
return UniConstraintBuilder(self.delegate.indictWith(
function_cast(indictment_function, self.a_type)), self.a_type)

def justify_with(self, justification_function: Callable[[A, ScoreType], Any]) -> \
def justify_with(self, justification_function: Callable[[A, ScoreType], 'api.ConstraintJustification']) -> \
'UniConstraintBuilder[A, ScoreType]':
from ai.timefold.solver.core.api.score import Score
return UniConstraintBuilder(self.delegate.justifyWith(
function_cast(justification_function, self.a_type)), self.a_type)
function_cast(justification_function, self.a_type, Score)), self.a_type)

def as_constraint(self, constraint_package_or_name: str, constraint_name: str = None) -> '_JavaConstraint':
if constraint_name is None:
Expand All @@ -58,10 +60,11 @@ def indict_with(self, indictment_function: Callable[[A, B], Collection]) -> 'BiC
return BiConstraintBuilder(self.delegate.indictWith(
function_cast(indictment_function, self.a_type, self.b_type)), self.a_type, self.b_type)

def justify_with(self, justification_function: Callable[[A, B, ScoreType], Any]) -> \
def justify_with(self, justification_function: Callable[[A, B, ScoreType], 'api.ConstraintJustification']) -> \
'BiConstraintBuilder[A, B, ScoreType]':
from ai.timefold.solver.core.api.score import Score
return BiConstraintBuilder(self.delegate.justifyWith(
function_cast(justification_function, self.a_type, self.b_type)), self.a_type, self.b_type)
function_cast(justification_function, self.a_type, self.b_type, Score)), self.a_type, self.b_type)

def as_constraint(self, constraint_package_or_name: str, constraint_name: str = None) -> '_JavaConstraint':
if constraint_name is None:
Expand Down Expand Up @@ -89,11 +92,12 @@ def indict_with(self, indictment_function: Callable[[A, B, C], Collection]) -> \
function_cast(indictment_function, self.a_type, self.b_type, self.c_type)), self.a_type, self.b_type,
self.c_type)

def justify_with(self, justification_function: Callable[[A, B, C, ScoreType], Any]) -> \
def justify_with(self, justification_function: Callable[[A, B, C, ScoreType], 'api.ConstraintJustification']) -> \
'TriConstraintBuilder[A, B, C, ScoreType]':
from ai.timefold.solver.core.api.score import Score
return TriConstraintBuilder(self.delegate.justifyWith(
function_cast(justification_function, self.a_type, self.b_type, self.c_type)), self.a_type, self.b_type,
self.c_type)
function_cast(justification_function, self.a_type, self.b_type, self.c_type, Score)),
self.a_type, self.b_type, self.c_type)

def as_constraint(self, constraint_package_or_name: str, constraint_name: str = None) -> '_JavaConstraint':
if constraint_name is None:
Expand Down Expand Up @@ -123,10 +127,11 @@ def indict_with(self, indictment_function: Callable[[A, B, C, D], Collection]) -
function_cast(indictment_function, self.a_type, self.b_type, self.c_type, self.d_type)),
self.a_type, self.b_type, self.c_type, self.d_type)

def justify_with(self, justification_function: Callable[[A, B, C, D, ScoreType], Any]) -> \
'QuadConstraintBuilder[A, B, C, D, ScoreType]':
def justify_with(self, justification_function: Callable[[A, B, C, D, ScoreType], 'api.ConstraintJustification']) \
-> 'QuadConstraintBuilder[A, B, C, D, ScoreType]':
from ai.timefold.solver.core.api.score import Score
return QuadConstraintBuilder(self.delegate.justifyWith(
function_cast(justification_function, self.a_type, self.b_type, self.c_type, self.d_type)),
function_cast(justification_function, self.a_type, self.b_type, self.c_type, self.d_type, Score)),
self.a_type, self.b_type, self.c_type, self.d_type)

def as_constraint(self, constraint_package_or_name: str, constraint_name: str = None) -> '_JavaConstraint':
Expand Down
Loading