diff --git a/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/InterfaceProxyGenerator.java b/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/InterfaceProxyGenerator.java index 8dc1fb7..5e6e1ee 100644 --- a/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/InterfaceProxyGenerator.java +++ b/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/InterfaceProxyGenerator.java @@ -2,13 +2,11 @@ import java.lang.reflect.Method; import java.lang.reflect.Modifier; -import java.util.Collections; import java.util.HashSet; import java.util.IdentityHashMap; -import java.util.List; -import java.util.Map; import java.util.Set; +import ai.timefold.jpyinterpreter.implementors.DelegatingInterfaceImplementor; import ai.timefold.jpyinterpreter.implementors.JavaPythonTypeConversionImplementor; import ai.timefold.jpyinterpreter.types.BuiltinTypes; import ai.timefold.jpyinterpreter.types.PythonLikeType; @@ -16,7 +14,6 @@ import ai.timefold.jpyinterpreter.util.arguments.ArgumentSpec; import org.objectweb.asm.ClassWriter; -import org.objectweb.asm.MethodVisitor; import org.objectweb.asm.Opcodes; import org.objectweb.asm.Type; @@ -251,27 +248,6 @@ private static void createMethodDelegate(ClassWriter classWriter, interfaceMethodVisitor.visitFieldInsn(Opcodes.GETSTATIC, wrapperInternalName, "argumentSpec$" + interfaceMethod.getName(), Type.getDescriptor(ArgumentSpec.class)); - interfaceMethodVisitor.visitLdcInsn(interfaceMethod.getParameterCount()); - interfaceMethodVisitor.visitTypeInsn(Opcodes.ANEWARRAY, Type.getInternalName(PythonLikeObject.class)); - interfaceMethodVisitor.visitVarInsn(Opcodes.ASTORE, interfaceMethod.getParameterCount() + 2); - for (int i = 0; i < interfaceMethod.getParameterCount(); i++) { - var parameterType = interfaceMethod.getParameterTypes()[i]; - interfaceMethodVisitor.visitVarInsn(Opcodes.ALOAD, interfaceMethod.getParameterCount() + 2); - interfaceMethodVisitor.visitLdcInsn(i); - interfaceMethodVisitor.visitVarInsn(Type.getType(parameterType).getOpcode(Opcodes.ILOAD), - i + 1); - if (parameterType.isPrimitive()) { - convertPrimitiveToObjectType(parameterType, interfaceMethodVisitor); - } - interfaceMethodVisitor.visitVarInsn(Opcodes.ALOAD, interfaceMethod.getParameterCount() + 1); - interfaceMethodVisitor.visitMethodInsn(Opcodes.INVOKESTATIC, - Type.getInternalName(JavaPythonTypeConversionImplementor.class), - "wrapJavaObject", - Type.getMethodDescriptor(Type.getType(PythonLikeObject.class), Type.getType(Object.class), Type.getType( - Map.class)), - false); - interfaceMethodVisitor.visitInsn(Opcodes.AASTORE); - } var functionSignature = delegateType.getMethodType(interfaceMethod.getName()) .orElseThrow(() -> new IllegalArgumentException( @@ -279,28 +255,11 @@ private static void createMethodDelegate(ClassWriter classWriter, .formatted(delegateType, interfaceMethod.getDeclaringClass(), interfaceMethod))) .getDefaultFunctionSignature() .orElseThrow(); - - interfaceMethodVisitor.visitVarInsn(Opcodes.ALOAD, interfaceMethod.getParameterCount() + 2); - interfaceMethodVisitor.visitMethodInsn(Opcodes.INVOKESTATIC, Type.getInternalName(List.class), - "of", Type.getMethodDescriptor(Type.getType(List.class), Type.getType(Object[].class)), - true); - interfaceMethodVisitor.visitMethodInsn(Opcodes.INVOKESTATIC, Type.getInternalName(Collections.class), - "emptyMap", Type.getMethodDescriptor(Type.getType(Map.class)), false); - interfaceMethodVisitor.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(ArgumentSpec.class), - "extractArgumentList", Type.getMethodDescriptor( - Type.getType(List.class), Type.getType(List.class), Type.getType(Map.class)), + DelegatingInterfaceImplementor.prepareParametersForMethodCallFromArgumentSpec( + interfaceMethod, interfaceMethodVisitor, functionSignature.getParameterTypes().length, + Type.getType(functionSignature.getMethodDescriptor().getMethodDescriptor()), false); - for (int i = 0; i < functionSignature.getParameterTypes().length; i++) { - interfaceMethodVisitor.visitInsn(Opcodes.DUP); - interfaceMethodVisitor.visitLdcInsn(i); - interfaceMethodVisitor.visitMethodInsn(Opcodes.INVOKEINTERFACE, Type.getInternalName(List.class), - "get", Type.getMethodDescriptor(Type.getType(Object.class), Type.INT_TYPE), true); - interfaceMethodVisitor.visitTypeInsn(Opcodes.CHECKCAST, - functionSignature.getParameterTypes()[i].getJavaTypeInternalName()); - interfaceMethodVisitor.visitInsn(Opcodes.SWAP); - } - interfaceMethodVisitor.visitInsn(Opcodes.POP); functionSignature.getMethodDescriptor().callMethod(interfaceMethodVisitor); var returnType = interfaceMethod.getReturnType(); @@ -308,7 +267,7 @@ private static void createMethodDelegate(ClassWriter classWriter, interfaceMethodVisitor.visitInsn(Opcodes.RETURN); } else { if (returnType.isPrimitive()) { - loadBoxedPrimitiveTypeClass(returnType, interfaceMethodVisitor); + DelegatingInterfaceImplementor.loadBoxedPrimitiveTypeClass(returnType, interfaceMethodVisitor); } else { interfaceMethodVisitor.visitLdcInsn(Type.getType(returnType)); } @@ -320,7 +279,7 @@ private static void createMethodDelegate(ClassWriter classWriter, PythonLikeObject.class)), false); if (returnType.isPrimitive()) { - unboxBoxedPrimitiveType(returnType, interfaceMethodVisitor); + DelegatingInterfaceImplementor.unboxBoxedPrimitiveType(returnType, interfaceMethodVisitor); interfaceMethodVisitor.visitInsn(Type.getType(returnType).getOpcode(Opcodes.IRETURN)); } else { interfaceMethodVisitor.visitTypeInsn(Opcodes.CHECKCAST, Type.getInternalName(returnType)); @@ -330,94 +289,4 @@ private static void createMethodDelegate(ClassWriter classWriter, interfaceMethodVisitor.visitMaxs(interfaceMethod.getParameterCount() + 2, 1); interfaceMethodVisitor.visitEnd(); } - - private static void convertPrimitiveToObjectType(Class primitiveType, MethodVisitor methodVisitor) { - if (primitiveType.equals(boolean.class)) { - methodVisitor.visitMethodInsn(Opcodes.INVOKESTATIC, Type.getInternalName(Boolean.class), - "valueOf", Type.getMethodDescriptor(Type.getType(Boolean.class), Type.BOOLEAN_TYPE), false); - } else if (primitiveType.equals(byte.class)) { - methodVisitor.visitMethodInsn(Opcodes.INVOKESTATIC, Type.getInternalName(Byte.class), - "valueOf", Type.getMethodDescriptor(Type.getType(Byte.class), Type.BYTE_TYPE), false); - } else if (primitiveType.equals(char.class)) { - methodVisitor.visitMethodInsn(Opcodes.INVOKESTATIC, Type.getInternalName(Character.class), - "valueOf", Type.getMethodDescriptor(Type.getType(Character.class), Type.CHAR_TYPE), false); - } else if (primitiveType.equals(short.class)) { - methodVisitor.visitMethodInsn(Opcodes.INVOKESTATIC, Type.getInternalName(Short.class), - "valueOf", Type.getMethodDescriptor(Type.getType(Short.class), Type.SHORT_TYPE), false); - } else if (primitiveType.equals(int.class)) { - methodVisitor.visitMethodInsn(Opcodes.INVOKESTATIC, Type.getInternalName(Integer.class), - "valueOf", Type.getMethodDescriptor(Type.getType(Integer.class), Type.INT_TYPE), false); - } else if (primitiveType.equals(long.class)) { - methodVisitor.visitMethodInsn(Opcodes.INVOKESTATIC, Type.getInternalName(Long.class), - "valueOf", Type.getMethodDescriptor(Type.getType(Long.class), Type.LONG_TYPE), false); - } else if (primitiveType.equals(float.class)) { - methodVisitor.visitMethodInsn(Opcodes.INVOKESTATIC, Type.getInternalName(Float.class), - "valueOf", Type.getMethodDescriptor(Type.getType(Float.class), Type.FLOAT_TYPE), false); - } else if (primitiveType.equals(double.class)) { - methodVisitor.visitMethodInsn(Opcodes.INVOKESTATIC, Type.getInternalName(Double.class), - "valueOf", Type.getMethodDescriptor(Type.getType(Double.class), Type.DOUBLE_TYPE), false); - } else { - throw new IllegalStateException("Unknown primitive type %s.".formatted(primitiveType)); - } - } - - private static void loadBoxedPrimitiveTypeClass(Class primitiveType, MethodVisitor methodVisitor) { - if (primitiveType.equals(boolean.class)) { - methodVisitor.visitLdcInsn(Type.getType(Boolean.class)); - } else if (primitiveType.equals(byte.class)) { - methodVisitor.visitLdcInsn(Type.getType(Byte.class)); - } else if (primitiveType.equals(char.class)) { - methodVisitor.visitLdcInsn(Type.getType(Character.class)); - } else if (primitiveType.equals(short.class)) { - methodVisitor.visitLdcInsn(Type.getType(Short.class)); - } else if (primitiveType.equals(int.class)) { - methodVisitor.visitLdcInsn(Type.getType(Integer.class)); - } else if (primitiveType.equals(long.class)) { - methodVisitor.visitLdcInsn(Type.getType(Long.class)); - } else if (primitiveType.equals(float.class)) { - methodVisitor.visitLdcInsn(Type.getType(Float.class)); - } else if (primitiveType.equals(double.class)) { - methodVisitor.visitLdcInsn(Type.getType(Double.class)); - } else { - throw new IllegalStateException("Unknown primitive type %s.".formatted(primitiveType)); - } - } - - private static void unboxBoxedPrimitiveType(Class primitiveType, MethodVisitor methodVisitor) { - if (primitiveType.equals(boolean.class)) { - methodVisitor.visitTypeInsn(Opcodes.CHECKCAST, Type.getInternalName(Boolean.class)); - methodVisitor.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Boolean.class), - "booleanValue", Type.getMethodDescriptor(Type.BOOLEAN_TYPE), false); - } else if (primitiveType.equals(byte.class)) { - methodVisitor.visitTypeInsn(Opcodes.CHECKCAST, Type.getInternalName(Byte.class)); - methodVisitor.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Byte.class), - "byteValue", Type.getMethodDescriptor(Type.BYTE_TYPE), false); - } else if (primitiveType.equals(char.class)) { - methodVisitor.visitTypeInsn(Opcodes.CHECKCAST, Type.getInternalName(Character.class)); - methodVisitor.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Character.class), - "charValue", Type.getMethodDescriptor(Type.CHAR_TYPE), false); - } else if (primitiveType.equals(short.class)) { - methodVisitor.visitTypeInsn(Opcodes.CHECKCAST, Type.getInternalName(Short.class)); - methodVisitor.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Short.class), - "shortValue", Type.getMethodDescriptor(Type.SHORT_TYPE), false); - } else if (primitiveType.equals(int.class)) { - methodVisitor.visitTypeInsn(Opcodes.CHECKCAST, Type.getInternalName(Integer.class)); - methodVisitor.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Integer.class), - "intValue", Type.getMethodDescriptor(Type.INT_TYPE), false); - } else if (primitiveType.equals(long.class)) { - methodVisitor.visitTypeInsn(Opcodes.CHECKCAST, Type.getInternalName(Long.class)); - methodVisitor.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Long.class), - "longValue", Type.getMethodDescriptor(Type.LONG_TYPE), false); - } else if (primitiveType.equals(float.class)) { - methodVisitor.visitTypeInsn(Opcodes.CHECKCAST, Type.getInternalName(Float.class)); - methodVisitor.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Float.class), - "floatValue", Type.getMethodDescriptor(Type.FLOAT_TYPE), false); - } else if (primitiveType.equals(double.class)) { - methodVisitor.visitTypeInsn(Opcodes.CHECKCAST, Type.getInternalName(Double.class)); - methodVisitor.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Double.class), - "doubleValue", Type.getMethodDescriptor(Type.DOUBLE_TYPE), false); - } else { - throw new IllegalStateException("Unknown primitive type %s.".formatted(primitiveType)); - } - } } diff --git a/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/PythonClassTranslator.java b/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/PythonClassTranslator.java index 48d071a..0984d65 100644 --- a/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/PythonClassTranslator.java +++ b/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/PythonClassTranslator.java @@ -19,6 +19,7 @@ import java.util.stream.Collectors; import ai.timefold.jpyinterpreter.dag.FlowGraph; +import ai.timefold.jpyinterpreter.implementors.DelegatingInterfaceImplementor; import ai.timefold.jpyinterpreter.implementors.JavaComparableImplementor; import ai.timefold.jpyinterpreter.implementors.JavaEqualsImplementor; import ai.timefold.jpyinterpreter.implementors.JavaHashCodeImplementor; @@ -94,6 +95,7 @@ public static PythonLikeType translatePythonClass(PythonCompiledClass pythonComp var className = preparedClassInfo.className; var internalClassName = preparedClassInfo.classInternalName; + Map instanceMethodNameToMethodDescriptor = new HashMap<>(); Set superTypeSet; Set javaInterfaceImplementorSet = new HashSet<>(); @@ -118,6 +120,11 @@ public static PythonLikeType translatePythonClass(PythonCompiledClass pythonComp } } + for (Class javaInterface : pythonCompiledClass.javaInterfaces) { + javaInterfaceImplementorSet.add( + new DelegatingInterfaceImplementor(internalClassName, javaInterface, instanceMethodNameToMethodDescriptor)); + } + if (pythonCompiledClass.superclassList.isEmpty()) { superTypeSet = Set.of(CPythonBackedPythonLikeObject.CPYTHON_BACKED_OBJECT_TYPE); } else { @@ -159,7 +166,8 @@ public static PythonLikeType translatePythonClass(PythonCompiledClass pythonComp List nonObjectInterfaceImplementors = javaInterfaceImplementorSet.stream() .filter(implementor -> !Object.class.equals(implementor.getInterfaceClass())) - .collect(Collectors.toList()); + .toList(); + String[] interfaces = new String[nonObjectInterfaceImplementors.size()]; for (int i = 0; i < nonObjectInterfaceImplementors.size(); i++) { interfaces[i] = Type.getInternalName(nonObjectInterfaceImplementors.get(i).getInterfaceClass()); @@ -294,7 +302,7 @@ public static PythonLikeType translatePythonClass(PythonCompiledClass pythonComp .entrySet()) { instanceMethodEntry.getValue().methodKind = PythonMethodKind.VIRTUAL_METHOD; createInstanceMethod(pythonLikeType, classWriter, internalClassName, instanceMethodEntry.getKey(), - instanceMethodEntry.getValue()); + instanceMethodEntry.getValue(), instanceMethodNameToMethodDescriptor); } for (Map.Entry staticMethodEntry : pythonCompiledClass.staticFunctionNameToPythonBytecode @@ -854,13 +862,15 @@ private static void addAnnotationsToMethod(PythonCompiledFunction function, Meth } private static void createInstanceMethod(PythonLikeType pythonLikeType, ClassWriter classWriter, String internalClassName, - String methodName, PythonCompiledFunction function) { + String methodName, PythonCompiledFunction function, + Map instanceMethodNameToMethodDescriptor) { InterfaceDeclaration interfaceDeclaration = getInterfaceForInstancePythonFunction(internalClassName, function); - String interfaceDescriptor = 'L' + interfaceDeclaration.interfaceName + ';'; + String interfaceDescriptor = interfaceDeclaration.descriptor(); String javaMethodName = getJavaMethodName(methodName); classWriter.visitField(Modifier.PUBLIC | Modifier.STATIC, javaMethodName, interfaceDescriptor, null, null); + instanceMethodNameToMethodDescriptor.put(methodName, interfaceDeclaration); Type returnType = getVirtualFunctionReturnType(function); List parameterPythonTypeList = function.getParameterTypes(); @@ -1555,30 +1565,13 @@ public static PythonLikeType getPythonReturnTypeOfFunction(PythonCompiledFunctio } } - public static class InterfaceDeclaration { - final String interfaceName; - final String methodDescriptor; - - public InterfaceDeclaration(String interfaceName, String methodDescriptor) { - this.interfaceName = interfaceName; - this.methodDescriptor = methodDescriptor; - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - InterfaceDeclaration that = (InterfaceDeclaration) o; - return interfaceName.equals(that.interfaceName) && methodDescriptor.equals(that.methodDescriptor); + public record InterfaceDeclaration(String interfaceName, String methodDescriptor) { + public String descriptor() { + return "L" + interfaceName + ";"; } - @Override - public int hashCode() { - return Objects.hash(interfaceName, methodDescriptor); + public Type methodType() { + return Type.getMethodType(methodDescriptor); } } diff --git a/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/PythonCompiledClass.java b/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/PythonCompiledClass.java index 3b0a6aa..9210cfe 100644 --- a/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/PythonCompiledClass.java +++ b/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/PythonCompiledClass.java @@ -36,6 +36,11 @@ public class PythonCompiledClass { */ public Map typeAnnotations; + /** + * Java interfaces the class implement + */ + public List> javaInterfaces; + /** * The binary type of this PythonCompiledClass; * typically {@link CPythonType}. Used when methods diff --git a/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/implementors/DelegatingInterfaceImplementor.java b/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/implementors/DelegatingInterfaceImplementor.java new file mode 100644 index 0000000..3365750 --- /dev/null +++ b/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/implementors/DelegatingInterfaceImplementor.java @@ -0,0 +1,280 @@ +package ai.timefold.jpyinterpreter.implementors; + +import java.lang.reflect.Field; +import java.lang.reflect.Method; +import java.lang.reflect.Modifier; +import java.util.Collections; +import java.util.IdentityHashMap; +import java.util.List; +import java.util.Map; + +import ai.timefold.jpyinterpreter.PythonBytecodeToJavaBytecodeTranslator; +import ai.timefold.jpyinterpreter.PythonClassTranslator; +import ai.timefold.jpyinterpreter.PythonCompiledClass; +import ai.timefold.jpyinterpreter.PythonLikeObject; +import ai.timefold.jpyinterpreter.types.errors.TypeError; +import ai.timefold.jpyinterpreter.util.MethodVisitorAdapters; +import ai.timefold.jpyinterpreter.util.arguments.ArgumentSpec; + +import org.objectweb.asm.ClassWriter; +import org.objectweb.asm.MethodVisitor; +import org.objectweb.asm.Opcodes; +import org.objectweb.asm.Type; + +public class DelegatingInterfaceImplementor extends JavaInterfaceImplementor { + final String internalClassName; + final Class interfaceClass; + final Map methodNameToFieldDescriptor; + + public DelegatingInterfaceImplementor(String internalClassName, Class interfaceClass, + Map methodNameToFieldDescriptor) { + this.internalClassName = internalClassName; + this.interfaceClass = interfaceClass; + this.methodNameToFieldDescriptor = methodNameToFieldDescriptor; + } + + @Override + public Class getInterfaceClass() { + return interfaceClass; + } + + @Override + public void implement(ClassWriter classWriter, PythonCompiledClass compiledClass) { + for (Method method : interfaceClass.getMethods()) { + if (!Modifier.isStatic(method.getModifiers()) && method.getDeclaringClass().isInterface()) { + implementMethod(classWriter, compiledClass, method); + } + } + } + + private void implementMethod(ClassWriter classWriter, PythonCompiledClass compiledClass, Method interfaceMethod) { + if (!methodNameToFieldDescriptor.containsKey(interfaceMethod.getName())) { + if (interfaceMethod.isDefault()) { + return; + } else { + throw new TypeError("Class %s cannot implement interface %s because it does not implement method %s." + .formatted(compiledClass.className, interfaceMethod.getDeclaringClass().getName(), + interfaceMethod.getName())); + } + } + var interfaceMethodDescriptor = Type.getMethodDescriptor(interfaceMethod); + + // Generates interfaceMethod(A a, B b, ...) { return delegate.interfaceMethod(a, b, ...); } + var interfaceMethodVisitor = classWriter.visitMethod(Modifier.PUBLIC, interfaceMethod.getName(), + interfaceMethodDescriptor, null, null); + + interfaceMethodVisitor = + MethodVisitorAdapters.adapt(interfaceMethodVisitor, interfaceMethod.getName(), interfaceMethodDescriptor); + + for (var parameter : interfaceMethod.getParameters()) { + interfaceMethodVisitor.visitParameter(parameter.getName(), 0); + } + interfaceMethodVisitor.visitCode(); + interfaceMethodVisitor.visitVarInsn(Opcodes.ALOAD, 0); + interfaceMethodVisitor.visitTypeInsn(Opcodes.NEW, Type.getInternalName(IdentityHashMap.class)); + interfaceMethodVisitor.visitInsn(Opcodes.DUP); + interfaceMethodVisitor.visitMethodInsn(Opcodes.INVOKESPECIAL, Type.getInternalName(IdentityHashMap.class), + "", Type.getMethodDescriptor(Type.VOID_TYPE), false); + interfaceMethodVisitor.visitVarInsn(Opcodes.ASTORE, interfaceMethod.getParameterCount() + 1); + + // Generates TOS = MyClass.methodInstance.getClass().getField(ARGUMENT_SPEC_INSTANCE_FIELD).get(MyClass.methodInstance); + var functionInterfaceDeclaration = methodNameToFieldDescriptor.get(interfaceMethod.getName()); + interfaceMethodVisitor.visitVarInsn(Opcodes.ALOAD, 0); + interfaceMethodVisitor.visitFieldInsn(Opcodes.GETSTATIC, internalClassName, + PythonClassTranslator.getJavaMethodName(interfaceMethod.getName()), + functionInterfaceDeclaration.descriptor()); + interfaceMethodVisitor.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Object.class), + "getClass", Type.getMethodDescriptor(Type.getType(Class.class)), false); + interfaceMethodVisitor.visitLdcInsn(PythonBytecodeToJavaBytecodeTranslator.ARGUMENT_SPEC_INSTANCE_FIELD_NAME); + interfaceMethodVisitor.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Class.class), + "getField", Type.getMethodDescriptor(Type.getType(Field.class), Type.getType(String.class)), false); + interfaceMethodVisitor.visitFieldInsn(Opcodes.GETSTATIC, internalClassName, + PythonClassTranslator.getJavaMethodName(interfaceMethod.getName()), + functionInterfaceDeclaration.descriptor()); + interfaceMethodVisitor.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Field.class), + "get", Type.getMethodDescriptor(Type.getType(Object.class), Type.getType(Object.class)), false); + interfaceMethodVisitor.visitTypeInsn(Opcodes.CHECKCAST, Type.getInternalName(ArgumentSpec.class)); + var methodType = functionInterfaceDeclaration.methodType(); + int argumentCount = methodType.getArgumentCount(); + + prepareParametersForMethodCallFromArgumentSpec(interfaceMethod, interfaceMethodVisitor, argumentCount, methodType, + true); + + Type[] javaParameterTypes = new Type[Math.max(0, argumentCount - 1)]; + + for (int i = 1; i < argumentCount; i++) { + javaParameterTypes[i - 1] = methodType.getArgumentTypes()[i]; + } + String javaMethodDescriptor = Type.getMethodDescriptor(methodType.getReturnType(), javaParameterTypes); + + interfaceMethodVisitor.visitMethodInsn(Opcodes.INVOKEVIRTUAL, internalClassName, + PythonClassTranslator.getJavaMethodName(interfaceMethod.getName()), + javaMethodDescriptor, false); + + var returnType = interfaceMethod.getReturnType(); + if (returnType.equals(void.class)) { + interfaceMethodVisitor.visitInsn(Opcodes.RETURN); + } else { + if (returnType.isPrimitive()) { + loadBoxedPrimitiveTypeClass(returnType, interfaceMethodVisitor); + } else { + interfaceMethodVisitor.visitLdcInsn(Type.getType(returnType)); + } + interfaceMethodVisitor.visitInsn(Opcodes.SWAP); + interfaceMethodVisitor.visitMethodInsn(Opcodes.INVOKESTATIC, + Type.getInternalName(JavaPythonTypeConversionImplementor.class), + "convertPythonObjectToJavaType", + Type.getMethodDescriptor(Type.getType(Object.class), Type.getType(Class.class), Type.getType( + PythonLikeObject.class)), + false); + if (returnType.isPrimitive()) { + unboxBoxedPrimitiveType(returnType, interfaceMethodVisitor); + interfaceMethodVisitor.visitInsn(Type.getType(returnType).getOpcode(Opcodes.IRETURN)); + } else { + interfaceMethodVisitor.visitTypeInsn(Opcodes.CHECKCAST, Type.getInternalName(returnType)); + interfaceMethodVisitor.visitInsn(Opcodes.ARETURN); + } + } + interfaceMethodVisitor.visitMaxs(interfaceMethod.getParameterCount() + 2, 1); + interfaceMethodVisitor.visitEnd(); + } + + public static void prepareParametersForMethodCallFromArgumentSpec(Method interfaceMethod, + MethodVisitor interfaceMethodVisitor, int argumentCount, + Type methodType, boolean skipSelf) { + int parameterStart = skipSelf ? 1 : 0; + interfaceMethodVisitor.visitLdcInsn(interfaceMethod.getParameterCount()); + interfaceMethodVisitor.visitTypeInsn(Opcodes.ANEWARRAY, Type.getInternalName(PythonLikeObject.class)); + interfaceMethodVisitor.visitVarInsn(Opcodes.ASTORE, interfaceMethod.getParameterCount() + 2); + for (int i = 0; i < interfaceMethod.getParameterCount(); i++) { + var parameterType = interfaceMethod.getParameterTypes()[i]; + interfaceMethodVisitor.visitVarInsn(Opcodes.ALOAD, interfaceMethod.getParameterCount() + 2); + interfaceMethodVisitor.visitLdcInsn(i); + interfaceMethodVisitor.visitVarInsn(Type.getType(parameterType).getOpcode(Opcodes.ILOAD), + i + 1); + if (parameterType.isPrimitive()) { + convertPrimitiveToObjectType(parameterType, interfaceMethodVisitor); + } + interfaceMethodVisitor.visitVarInsn(Opcodes.ALOAD, interfaceMethod.getParameterCount() + 1); + interfaceMethodVisitor.visitMethodInsn(Opcodes.INVOKESTATIC, + Type.getInternalName(JavaPythonTypeConversionImplementor.class), + "wrapJavaObject", + Type.getMethodDescriptor(Type.getType(PythonLikeObject.class), Type.getType(Object.class), Type.getType( + Map.class)), + false); + interfaceMethodVisitor.visitInsn(Opcodes.AASTORE); + } + + interfaceMethodVisitor.visitVarInsn(Opcodes.ALOAD, interfaceMethod.getParameterCount() + 2); + interfaceMethodVisitor.visitMethodInsn(Opcodes.INVOKESTATIC, Type.getInternalName(List.class), + "of", Type.getMethodDescriptor(Type.getType(List.class), Type.getType(Object[].class)), + true); + interfaceMethodVisitor.visitMethodInsn(Opcodes.INVOKESTATIC, Type.getInternalName(Collections.class), + "emptyMap", Type.getMethodDescriptor(Type.getType(Map.class)), false); + interfaceMethodVisitor.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(ArgumentSpec.class), + "extractArgumentList", Type.getMethodDescriptor( + Type.getType(List.class), Type.getType(List.class), Type.getType(Map.class)), + false); + + for (int i = 0; i < argumentCount - parameterStart; i++) { + interfaceMethodVisitor.visitInsn(Opcodes.DUP); + interfaceMethodVisitor.visitLdcInsn(i); + interfaceMethodVisitor.visitMethodInsn(Opcodes.INVOKEINTERFACE, Type.getInternalName(List.class), + "get", Type.getMethodDescriptor(Type.getType(Object.class), Type.INT_TYPE), true); + interfaceMethodVisitor.visitTypeInsn(Opcodes.CHECKCAST, + methodType.getArgumentTypes()[i + parameterStart].getInternalName()); + interfaceMethodVisitor.visitInsn(Opcodes.SWAP); + } + interfaceMethodVisitor.visitInsn(Opcodes.POP); + } + + public static void convertPrimitiveToObjectType(Class primitiveType, MethodVisitor methodVisitor) { + if (primitiveType.equals(boolean.class)) { + methodVisitor.visitMethodInsn(Opcodes.INVOKESTATIC, Type.getInternalName(Boolean.class), + "valueOf", Type.getMethodDescriptor(Type.getType(Boolean.class), Type.BOOLEAN_TYPE), false); + } else if (primitiveType.equals(byte.class)) { + methodVisitor.visitMethodInsn(Opcodes.INVOKESTATIC, Type.getInternalName(Byte.class), + "valueOf", Type.getMethodDescriptor(Type.getType(Byte.class), Type.BYTE_TYPE), false); + } else if (primitiveType.equals(char.class)) { + methodVisitor.visitMethodInsn(Opcodes.INVOKESTATIC, Type.getInternalName(Character.class), + "valueOf", Type.getMethodDescriptor(Type.getType(Character.class), Type.CHAR_TYPE), false); + } else if (primitiveType.equals(short.class)) { + methodVisitor.visitMethodInsn(Opcodes.INVOKESTATIC, Type.getInternalName(Short.class), + "valueOf", Type.getMethodDescriptor(Type.getType(Short.class), Type.SHORT_TYPE), false); + } else if (primitiveType.equals(int.class)) { + methodVisitor.visitMethodInsn(Opcodes.INVOKESTATIC, Type.getInternalName(Integer.class), + "valueOf", Type.getMethodDescriptor(Type.getType(Integer.class), Type.INT_TYPE), false); + } else if (primitiveType.equals(long.class)) { + methodVisitor.visitMethodInsn(Opcodes.INVOKESTATIC, Type.getInternalName(Long.class), + "valueOf", Type.getMethodDescriptor(Type.getType(Long.class), Type.LONG_TYPE), false); + } else if (primitiveType.equals(float.class)) { + methodVisitor.visitMethodInsn(Opcodes.INVOKESTATIC, Type.getInternalName(Float.class), + "valueOf", Type.getMethodDescriptor(Type.getType(Float.class), Type.FLOAT_TYPE), false); + } else if (primitiveType.equals(double.class)) { + methodVisitor.visitMethodInsn(Opcodes.INVOKESTATIC, Type.getInternalName(Double.class), + "valueOf", Type.getMethodDescriptor(Type.getType(Double.class), Type.DOUBLE_TYPE), false); + } else { + throw new IllegalStateException("Unknown primitive type %s.".formatted(primitiveType)); + } + } + + public static void loadBoxedPrimitiveTypeClass(Class primitiveType, MethodVisitor methodVisitor) { + if (primitiveType.equals(boolean.class)) { + methodVisitor.visitLdcInsn(Type.getType(Boolean.class)); + } else if (primitiveType.equals(byte.class)) { + methodVisitor.visitLdcInsn(Type.getType(Byte.class)); + } else if (primitiveType.equals(char.class)) { + methodVisitor.visitLdcInsn(Type.getType(Character.class)); + } else if (primitiveType.equals(short.class)) { + methodVisitor.visitLdcInsn(Type.getType(Short.class)); + } else if (primitiveType.equals(int.class)) { + methodVisitor.visitLdcInsn(Type.getType(Integer.class)); + } else if (primitiveType.equals(long.class)) { + methodVisitor.visitLdcInsn(Type.getType(Long.class)); + } else if (primitiveType.equals(float.class)) { + methodVisitor.visitLdcInsn(Type.getType(Float.class)); + } else if (primitiveType.equals(double.class)) { + methodVisitor.visitLdcInsn(Type.getType(Double.class)); + } else { + throw new IllegalStateException("Unknown primitive type %s.".formatted(primitiveType)); + } + } + + public static void unboxBoxedPrimitiveType(Class primitiveType, MethodVisitor methodVisitor) { + if (primitiveType.equals(boolean.class)) { + methodVisitor.visitTypeInsn(Opcodes.CHECKCAST, Type.getInternalName(Boolean.class)); + methodVisitor.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Boolean.class), + "booleanValue", Type.getMethodDescriptor(Type.BOOLEAN_TYPE), false); + } else if (primitiveType.equals(byte.class)) { + methodVisitor.visitTypeInsn(Opcodes.CHECKCAST, Type.getInternalName(Byte.class)); + methodVisitor.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Byte.class), + "byteValue", Type.getMethodDescriptor(Type.BYTE_TYPE), false); + } else if (primitiveType.equals(char.class)) { + methodVisitor.visitTypeInsn(Opcodes.CHECKCAST, Type.getInternalName(Character.class)); + methodVisitor.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Character.class), + "charValue", Type.getMethodDescriptor(Type.CHAR_TYPE), false); + } else if (primitiveType.equals(short.class)) { + methodVisitor.visitTypeInsn(Opcodes.CHECKCAST, Type.getInternalName(Short.class)); + methodVisitor.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Short.class), + "shortValue", Type.getMethodDescriptor(Type.SHORT_TYPE), false); + } else if (primitiveType.equals(int.class)) { + methodVisitor.visitTypeInsn(Opcodes.CHECKCAST, Type.getInternalName(Integer.class)); + methodVisitor.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Integer.class), + "intValue", Type.getMethodDescriptor(Type.INT_TYPE), false); + } else if (primitiveType.equals(long.class)) { + methodVisitor.visitTypeInsn(Opcodes.CHECKCAST, Type.getInternalName(Long.class)); + methodVisitor.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Long.class), + "longValue", Type.getMethodDescriptor(Type.LONG_TYPE), false); + } else if (primitiveType.equals(float.class)) { + methodVisitor.visitTypeInsn(Opcodes.CHECKCAST, Type.getInternalName(Float.class)); + methodVisitor.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Float.class), + "floatValue", Type.getMethodDescriptor(Type.FLOAT_TYPE), false); + } else if (primitiveType.equals(double.class)) { + methodVisitor.visitTypeInsn(Opcodes.CHECKCAST, Type.getInternalName(Double.class)); + methodVisitor.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Double.class), + "doubleValue", Type.getMethodDescriptor(Type.DOUBLE_TYPE), false); + } else { + throw new IllegalStateException("Unknown primitive type %s.".formatted(primitiveType)); + } + } +} diff --git a/jpyinterpreter/src/main/python/__init__.py b/jpyinterpreter/src/main/python/__init__.py index d9a825a..f3dfb8e 100644 --- a/jpyinterpreter/src/main/python/__init__.py +++ b/jpyinterpreter/src/main/python/__init__.py @@ -2,7 +2,7 @@ This module acts as an interface to the Python bytecode to Java bytecode interpreter """ from .jvm_setup import init, set_class_output_directory -from .annotations import JavaAnnotation, add_class_annotation +from .annotations import JavaAnnotation, add_class_annotation, add_java_interface from .conversions import (convert_to_java_python_like_object, unwrap_python_like_object, update_python_object_from_java, is_c_native) from .translator import (translate_python_bytecode_to_java_bytecode, diff --git a/jpyinterpreter/src/main/python/annotations.py b/jpyinterpreter/src/main/python/annotations.py index b3d9a35..c65bd50 100644 --- a/jpyinterpreter/src/main/python/annotations.py +++ b/jpyinterpreter/src/main/python/annotations.py @@ -20,7 +20,7 @@ def add_class_annotation(annotation_type, /, **annotation_values: Any) -> Callab def decorator(_cls: Type[T]) -> Type[T]: from .translator import type_to_compiled_java_class, type_to_annotations if _cls in type_to_compiled_java_class: - raise RuntimeError('Cannot add an annotation after a class been compiled.') + raise RuntimeError('Cannot add an annotation after a class has been compiled.') annotations = type_to_annotations.get(_cls, []) annotation = JavaAnnotation(annotation_type, annotation_values) annotations.append(annotation) @@ -30,6 +30,19 @@ def decorator(_cls: Type[T]) -> Type[T]: return decorator +def add_java_interface(java_interface: JClass | str, /) -> Callable[[Type[T]], Type[T]]: + def decorator(_cls: Type[T]) -> Type[T]: + from .translator import type_to_compiled_java_class, type_to_java_interfaces + if _cls in type_to_compiled_java_class: + raise RuntimeError('Cannot add an interface after a class has been compiled.') + marker_interfaces = type_to_java_interfaces.get(_cls, []) + marker_interfaces.append(java_interface) + type_to_java_interfaces[_cls] = marker_interfaces + return _cls + + return decorator + + def copy_type_annotations(hinted_object, default_args, vargs_name, kwargs_name): from java.util import HashMap, Collections from ai.timefold.jpyinterpreter import TypeHint diff --git a/jpyinterpreter/src/main/python/translator.py b/jpyinterpreter/src/main/python/translator.py index 90c399e..0ca0c85 100644 --- a/jpyinterpreter/src/main/python/translator.py +++ b/jpyinterpreter/src/main/python/translator.py @@ -13,6 +13,7 @@ global_dict_to_key_set = dict() type_to_compiled_java_class = dict() type_to_annotations = dict() +type_to_java_interfaces = dict() function_interface_pair_to_instance = dict() function_interface_pair_to_class = dict() @@ -629,9 +630,17 @@ def translate_python_class_to_java_class(python_class): python_compiled_class = PythonCompiledClass() python_compiled_class.annotations = ArrayList() + python_compiled_class.javaInterfaces = ArrayList() + for annotation in type_to_annotations.get(python_class, []): python_compiled_class.annotations.add(convert_java_annotation(annotation)) + for java_interface in type_to_java_interfaces.get(python_class, []): + if isinstance(java_interface, str): + java_interface = JClass(java_interface) + + python_compiled_class.javaInterfaces.add(java_interface) + python_compiled_class.binaryType = CPythonType.getType(JProxy(OpaquePythonReference, inst=python_class, convert=True)) python_compiled_class.module = python_class.__module__ diff --git a/jpyinterpreter/src/test/java/ai/timefold/jpyinterpreter/PythonClassTranslatorTest.java b/jpyinterpreter/src/test/java/ai/timefold/jpyinterpreter/PythonClassTranslatorTest.java index f6b94d9..6148c39 100644 --- a/jpyinterpreter/src/test/java/ai/timefold/jpyinterpreter/PythonClassTranslatorTest.java +++ b/jpyinterpreter/src/test/java/ai/timefold/jpyinterpreter/PythonClassTranslatorTest.java @@ -2,15 +2,19 @@ import static org.assertj.core.api.AssertionsForClassTypes.assertThat; +import java.util.Collections; import java.util.List; import java.util.Map; import java.util.function.Function; +import java.util.function.ToIntFunction; import ai.timefold.jpyinterpreter.opcodes.descriptor.ControlOpDescriptor; +import ai.timefold.jpyinterpreter.opcodes.descriptor.DunderOpDescriptor; import ai.timefold.jpyinterpreter.types.BuiltinTypes; import ai.timefold.jpyinterpreter.types.PythonLikeFunction; import ai.timefold.jpyinterpreter.types.PythonLikeType; import ai.timefold.jpyinterpreter.types.PythonString; +import ai.timefold.jpyinterpreter.types.collections.PythonLikeTuple; import ai.timefold.jpyinterpreter.types.numeric.PythonInteger; import ai.timefold.jpyinterpreter.util.PythonFunctionBuilder; @@ -41,7 +45,8 @@ public void testPythonClassTranslation() throws ClassNotFoundException, NoSuchMe .op(ControlOpDescriptor.RETURN_VALUE) .build(); - compiledClass.annotations = List.of(); + compiledClass.annotations = Collections.emptyList(); + compiledClass.javaInterfaces = Collections.emptyList(); compiledClass.className = "MyClass"; compiledClass.superclassList = List.of(BuiltinTypes.BASE_TYPE); compiledClass.staticAttributeNameToObject = Map.of("type_variable", new PythonString("type_value")); @@ -94,7 +99,8 @@ public void testPythonClassComparable() throws ClassNotFoundException { PythonCompiledFunction comparisonFunction = getCompareFunction.apply(compareOp); PythonCompiledClass compiledClass = new PythonCompiledClass(); - compiledClass.annotations = List.of(); + compiledClass.annotations = Collections.emptyList(); + compiledClass.javaInterfaces = Collections.emptyList(); compiledClass.className = "MyClass"; compiledClass.superclassList = List.of(BuiltinTypes.BASE_TYPE); compiledClass.staticAttributeNameToObject = Map.of(); @@ -163,7 +169,8 @@ public void testPythonClassEqualsAndHashCode() throws ClassNotFoundException { .build(); PythonCompiledClass compiledClass = new PythonCompiledClass(); - compiledClass.annotations = List.of(); + compiledClass.annotations = Collections.emptyList(); + compiledClass.javaInterfaces = Collections.emptyList(); compiledClass.className = "MyClass"; compiledClass.superclassList = List.of(BuiltinTypes.BASE_TYPE); compiledClass.staticAttributeNameToObject = Map.of(); @@ -209,4 +216,113 @@ public void testPythonClassEqualsAndHashCode() throws ClassNotFoundException { assertThat(object3.hashCode()) .isEqualTo(PythonInteger.valueOf(Long.MAX_VALUE).hashCode()); } + + @Test + public void testPythonClassSimpleInterface() throws ClassNotFoundException { + PythonCompiledFunction initFunction = PythonFunctionBuilder.newFunction("self", "value") + .loadParameter("value") + .loadParameter("self") + .storeAttribute("value") + .loadConstant(null) + .op(ControlOpDescriptor.RETURN_VALUE) + .build(); + + PythonCompiledFunction applyAsInt = PythonFunctionBuilder.newFunction("self", "value") + .loadParameter("self") + .getAttribute("value") + .loadParameter("value") + .op(DunderOpDescriptor.BINARY_ADD) + .op(ControlOpDescriptor.RETURN_VALUE) + .build(); + + PythonCompiledClass compiledClass = new PythonCompiledClass(); + compiledClass.annotations = Collections.emptyList(); + compiledClass.javaInterfaces = List.of(ToIntFunction.class); + compiledClass.className = "MyClass"; + compiledClass.superclassList = List.of(BuiltinTypes.BASE_TYPE); + compiledClass.staticAttributeNameToObject = Map.of(); + compiledClass.staticAttributeNameToClassInstance = Map.of(); + compiledClass.typeAnnotations = Map.of("key", TypeHint.withoutAnnotations(BuiltinTypes.INT_TYPE)); + compiledClass.instanceFunctionNameToPythonBytecode = Map.of("__init__", initFunction, + "applyAsInt", applyAsInt); + compiledClass.staticFunctionNameToPythonBytecode = Map.of(); + compiledClass.classFunctionNameToPythonBytecode = Map.of(); + + PythonLikeType classType = PythonClassTranslator.translatePythonClass(compiledClass); + Class generatedClass = BuiltinTypes.asmClassLoader.loadClass( + classType.getJavaTypeInternalName().replace('/', '.')); + + assertThat(generatedClass).hasPublicFields(PythonClassTranslator.getJavaFieldName("value")); + assertThat(generatedClass).hasPublicMethods( + PythonClassTranslator.getJavaMethodName("__init__"), + "applyAsInt"); + assertThat(generatedClass).isAssignableTo(ToIntFunction.class); + + var object1 = (ToIntFunction) classType.$call(List.of(PythonInteger.valueOf(1)), Map.of(), null); + var object2 = (ToIntFunction) classType.$call(List.of(PythonInteger.valueOf(2)), Map.of(), null); + var object3 = (ToIntFunction) classType.$call(List.of(PythonInteger.valueOf(3)), Map.of(), null); + + assertThat(object1.applyAsInt(PythonInteger.valueOf(1))).isEqualTo(2); + assertThat(object2.applyAsInt(PythonInteger.valueOf(1))).isEqualTo(3); + assertThat(object3.applyAsInt(PythonInteger.valueOf(1))).isEqualTo(4); + } + + public interface ComplexInterface { + int STATIC_FIELD = 10; + + static int staticMethod() { + return STATIC_FIELD; + } + + default void defaultMethod() { + } + + int overloadedMethod(); + + int overloadedMethod(int value); + } + + @Test + public void testPythonClassComplexInterface() throws ClassNotFoundException { + PythonCompiledFunction initFunction = PythonFunctionBuilder.newFunction("self") + .loadConstant(null) + .op(ControlOpDescriptor.RETURN_VALUE) + .build(); + + PythonCompiledFunction overloadedMethod = PythonFunctionBuilder.newFunction("self", "value") + .loadParameter("value") + .loadConstant(1) + .op(DunderOpDescriptor.BINARY_ADD) + .op(ControlOpDescriptor.RETURN_VALUE) + .build(); + + overloadedMethod.defaultPositionalArguments = PythonLikeTuple.fromItems(PythonInteger.ZERO); + + PythonCompiledClass compiledClass = new PythonCompiledClass(); + compiledClass.annotations = Collections.emptyList(); + compiledClass.javaInterfaces = List.of(ComplexInterface.class); + compiledClass.className = "MyClass"; + compiledClass.superclassList = List.of(BuiltinTypes.BASE_TYPE); + compiledClass.staticAttributeNameToObject = Map.of(); + compiledClass.staticAttributeNameToClassInstance = Map.of(); + compiledClass.typeAnnotations = Map.of("key", TypeHint.withoutAnnotations(BuiltinTypes.INT_TYPE)); + compiledClass.instanceFunctionNameToPythonBytecode = Map.of("__init__", initFunction, + "overloadedMethod", overloadedMethod); + compiledClass.staticFunctionNameToPythonBytecode = Map.of(); + compiledClass.classFunctionNameToPythonBytecode = Map.of(); + + PythonLikeType classType = PythonClassTranslator.translatePythonClass(compiledClass); + Class generatedClass = BuiltinTypes.asmClassLoader.loadClass( + classType.getJavaTypeInternalName().replace('/', '.')); + + assertThat(generatedClass).hasPublicMethods( + PythonClassTranslator.getJavaMethodName("__init__"), + "overloadedMethod"); + assertThat(generatedClass).isAssignableTo(ComplexInterface.class); + + var instance = (ComplexInterface) classType.$call(List.of(), Map.of(), null); + + assertThat(instance.overloadedMethod()).isEqualTo(1); + assertThat(instance.overloadedMethod(1)).isEqualTo(2); + } } diff --git a/jpyinterpreter/tests/test_builtins.py b/jpyinterpreter/tests/test_builtins.py index b210b84..23c9afb 100644 --- a/jpyinterpreter/tests/test_builtins.py +++ b/jpyinterpreter/tests/test_builtins.py @@ -16,7 +16,8 @@ def _get_exception_with_cause(exception): if exception is None: return None try: - raise Exception(f'{exception.getClass().getSimpleName()}: {exception.getMessage()}') + raise Exception(f'{exception.getClass().getSimpleName()}: {exception.getMessage()}\n' + f'{exception.stacktrace()}') except Exception as e: cause = _JavaException._get_exception_with_cause(exception.getCause()) if cause is not None: diff --git a/jpyinterpreter/tests/test_classes.py b/jpyinterpreter/tests/test_classes.py index 6c5e755..d6f8fb0 100644 --- a/jpyinterpreter/tests/test_classes.py +++ b/jpyinterpreter/tests/test_classes.py @@ -962,3 +962,42 @@ class A: translated_class = translate_python_class_to_java_class(A).getJavaClass() field_type = translated_class.getField('my_field').getGenericType() assert field_type.getActualTypeArguments()[0].getName() == PythonString.class_.getName() + + +def test_marker_interface(): + from ai.timefold.jpyinterpreter.types.wrappers import OpaquePythonReference + from jpyinterpreter import translate_python_class_to_java_class, add_java_interface + + @add_java_interface(OpaquePythonReference) + class A: + pass + + translated_class = translate_python_class_to_java_class(A).getJavaClass() + assert OpaquePythonReference.class_.isAssignableFrom(translated_class) + + +def test_marker_interface_string(): + from ai.timefold.jpyinterpreter.types.wrappers import OpaquePythonReference + from jpyinterpreter import translate_python_class_to_java_class, add_java_interface + + @add_java_interface('ai.timefold.jpyinterpreter.types.wrappers.OpaquePythonReference') + class A: + pass + + translated_class = translate_python_class_to_java_class(A).getJavaClass() + assert OpaquePythonReference.class_.isAssignableFrom(translated_class) + + +def test_functional_interface(): + from java.util.function import ToIntFunction + from jpyinterpreter import translate_python_class_to_java_class, add_java_interface + + @add_java_interface(ToIntFunction) + class A: + def applyAsInt(self, argument: int): + return argument + 1 + + translated_class = translate_python_class_to_java_class(A).getJavaClass() + assert ToIntFunction.class_.isAssignableFrom(translated_class) + java_object = translated_class.getConstructor().newInstance() + assert java_object.applyAsInt(1) == 2 diff --git a/tests/test_constraint_streams.py b/tests/test_constraint_streams.py index dae8d07..5aad0d6 100644 --- a/tests/test_constraint_streams.py +++ b/tests/test_constraint_streams.py @@ -517,6 +517,48 @@ def define_constraints(constraint_factory: ConstraintFactory): } +def test_custom_justifications(): + @dataclass(unsafe_hash=True) + class MyJustification(ConstraintJustification): + code: str + score: SimpleScore + + @constraint_provider + def define_constraints(constraint_factory: ConstraintFactory): + return [ + constraint_factory.for_each(Entity) + .reward(SimpleScore.ONE, lambda e: e.value.number) + .justify_with(lambda e, score: MyJustification(e.code, score)) + .as_constraint('my_package', 'Maximize value') + ] + + score_manager = create_score_manager(define_constraints) + entity_a: Entity = Entity('A') + entity_b: Entity = Entity('B') + + value_1 = Value(1) + value_2 = Value(2) + value_3 = Value(3) + + entity_a.value = value_1 + entity_b.value = value_3 + + problem = Solution([entity_a, entity_b], [value_1, value_2, value_3]) + + justifications = score_manager.explain(problem).get_justification_list() + assert len(justifications) == 2 + assert MyJustification('A', SimpleScore.of(1)) in justifications + assert MyJustification('B', SimpleScore.of(3)) in justifications + + justifications = score_manager.explain(problem).get_justification_list(MyJustification) + assert len(justifications) == 2 + assert MyJustification('A', SimpleScore.of(1)) in justifications + assert MyJustification('B', SimpleScore.of(3)) in justifications + + justifications = score_manager.explain(problem).get_justification_list(DefaultConstraintJustification) + assert len(justifications) == 0 + + ignored_python_functions = { '_call_comparison_java_joiner', '__init__', @@ -534,23 +576,19 @@ def define_constraints(constraint_factory: ConstraintFactory): 'countLongBi', # Python has no concept of Long (everything a BigInteger) 'countLongQuad', 'countLongTri', - '_handler', # JPype handler field should be ignored - # Unimplemented penalize/reward/impact 'impactBigDecimal', - 'impactConfigurable', 'impactConfigurableBigDecimal', 'impactConfigurableLong', 'impactLong', 'penalizeBigDecimal', - 'penalizeConfigurable', 'penalizeConfigurableBigDecimal', 'penalizeConfigurableLong', 'penalizeLong', 'rewardBigDecimal', - 'rewardConfigurable', 'rewardConfigurableBigDecimal', 'rewardConfigurableLong', 'rewardLong', + '_handler', # JPype handler field should be ignored # These methods are deprecated 'from_', 'fromUnfiltered', diff --git a/timefold-solver-python-core/src/main/python/api/_solution_manager.py b/timefold-solver-python-core/src/main/python/api/_solution_manager.py index 2e226aa..1cbbcd6 100644 --- a/timefold-solver-python-core/src/main/python/api/_solution_manager.py +++ b/timefold-solver-python-core/src/main/python/api/_solution_manager.py @@ -1,10 +1,10 @@ from ._solver_factory import SolverFactory from ._solver_manager import SolverManager from .._timefold_java_interop import get_class -from jpyinterpreter import unwrap_python_like_object +from jpyinterpreter import unwrap_python_like_object, add_java_interface from dataclasses import dataclass -from typing import TypeVar, Generic, Union, TYPE_CHECKING, Any, cast, Optional +from typing import TypeVar, Generic, Union, TYPE_CHECKING, Any, cast, Optional, Type if TYPE_CHECKING: # These imports require a JVM to be running, so only import if type checking @@ -24,6 +24,7 @@ Solution_ = TypeVar('Solution_') ProblemId_ = TypeVar('ProblemId_') Score_ = TypeVar('Score_', bound='Score') +Justification_ = TypeVar('Justification_', bound='ConstraintJustification') @dataclass(frozen=True, unsafe_hash=True) @@ -100,8 +101,13 @@ def __hash__(self) -> int: return combined_hash +@add_java_interface('ai.timefold.solver.core.api.score.stream.ConstraintJustification') +class ConstraintJustification: + pass + + @dataclass(frozen=True, eq=True) -class DefaultConstraintJustification: +class DefaultConstraintJustification(ConstraintJustification): facts: tuple[Any, ...] impact: Score_ @@ -127,7 +133,7 @@ def _map_constraint_match_set(constraint_match_set: set['_JavaConstraintMatch']) } -def _unwrap_justification(justification: Any) -> Any: +def _unwrap_justification(justification: Any) -> ConstraintJustification: from ai.timefold.solver.core.api.score.stream import ( DefaultConstraintJustification as _JavaDefaultConstraintJustification) if isinstance(justification, _JavaDefaultConstraintJustification): @@ -139,7 +145,7 @@ def _unwrap_justification(justification: Any) -> Any: return unwrap_python_like_object(justification) -def _unwrap_justification_list(justification_list: list[Any]) -> list[Any]: +def _unwrap_justification_list(justification_list: list[Any]) -> list[ConstraintJustification]: return [_unwrap_justification(justification) for justification in justification_list] @@ -163,7 +169,7 @@ def constraint_match_set(self) -> set[ConstraintMatch[Score_]]: def indicted_object(self) -> Any: return unwrap_python_like_object(self._delegate.getIndictedObject()) - def get_justification_list(self, justification_type=None) -> list[Any]: + def get_justification_list(self, justification_type: Type[Justification_] = None) -> list[Justification_]: if justification_type is None: justification_list = self._delegate.getJustificationList() else: @@ -250,7 +256,7 @@ def solution(self) -> Solution_: def summary(self) -> str: return self._delegate.getSummary() - def get_justification_list(self, justification_type=None) -> list[Any]: + def get_justification_list(self, justification_type: Type[Justification_] = None) -> list[Justification_]: if justification_type is None: justification_list = self._delegate.getJustificationList() else: @@ -275,7 +281,7 @@ def score(self) -> Score_: return self._delegate.score() @property - def justification(self) -> Any: + def justification(self) -> ConstraintJustification: return _unwrap_justification(self._delegate.justification()) @@ -343,5 +349,5 @@ def constraint_analyses(self) -> list[ConstraintAnalysis]: __all__ = ['SolutionManager', 'ScoreExplanation', 'ConstraintRef', 'ConstraintMatch', 'ConstraintMatchTotal', - 'DefaultConstraintJustification', 'Indictment', + 'ConstraintJustification', 'DefaultConstraintJustification', 'Indictment', 'ScoreAnalysis', 'ConstraintAnalysis', 'MatchAnalysis'] diff --git a/timefold-solver-python-core/src/main/python/constraint/_constraint_builder.py b/timefold-solver-python-core/src/main/python/constraint/_constraint_builder.py index df63e0c..2549ae8 100644 --- a/timefold-solver-python-core/src/main/python/constraint/_constraint_builder.py +++ b/timefold-solver-python-core/src/main/python/constraint/_constraint_builder.py @@ -1,4 +1,5 @@ from ._function_translator import function_cast +import timefold.solver.api as api from typing import TypeVar, Callable, Generic, Collection, Any, TYPE_CHECKING, Type if TYPE_CHECKING: @@ -31,10 +32,11 @@ def indict_with(self, indictment_function: Callable[[A], Collection]) -> 'UniCon return UniConstraintBuilder(self.delegate.indictWith( function_cast(indictment_function, self.a_type)), self.a_type) - def justify_with(self, justification_function: Callable[[A, ScoreType], Any]) -> \ + def justify_with(self, justification_function: Callable[[A, ScoreType], 'api.ConstraintJustification']) -> \ 'UniConstraintBuilder[A, ScoreType]': + from ai.timefold.solver.core.api.score import Score return UniConstraintBuilder(self.delegate.justifyWith( - function_cast(justification_function, self.a_type)), self.a_type) + function_cast(justification_function, self.a_type, Score)), self.a_type) def as_constraint(self, constraint_package_or_name: str, constraint_name: str = None) -> '_JavaConstraint': if constraint_name is None: @@ -58,10 +60,11 @@ def indict_with(self, indictment_function: Callable[[A, B], Collection]) -> 'BiC return BiConstraintBuilder(self.delegate.indictWith( function_cast(indictment_function, self.a_type, self.b_type)), self.a_type, self.b_type) - def justify_with(self, justification_function: Callable[[A, B, ScoreType], Any]) -> \ + def justify_with(self, justification_function: Callable[[A, B, ScoreType], 'api.ConstraintJustification']) -> \ 'BiConstraintBuilder[A, B, ScoreType]': + from ai.timefold.solver.core.api.score import Score return BiConstraintBuilder(self.delegate.justifyWith( - function_cast(justification_function, self.a_type, self.b_type)), self.a_type, self.b_type) + function_cast(justification_function, self.a_type, self.b_type, Score)), self.a_type, self.b_type) def as_constraint(self, constraint_package_or_name: str, constraint_name: str = None) -> '_JavaConstraint': if constraint_name is None: @@ -89,11 +92,12 @@ def indict_with(self, indictment_function: Callable[[A, B, C], Collection]) -> \ function_cast(indictment_function, self.a_type, self.b_type, self.c_type)), self.a_type, self.b_type, self.c_type) - def justify_with(self, justification_function: Callable[[A, B, C, ScoreType], Any]) -> \ + def justify_with(self, justification_function: Callable[[A, B, C, ScoreType], 'api.ConstraintJustification']) -> \ 'TriConstraintBuilder[A, B, C, ScoreType]': + from ai.timefold.solver.core.api.score import Score return TriConstraintBuilder(self.delegate.justifyWith( - function_cast(justification_function, self.a_type, self.b_type, self.c_type)), self.a_type, self.b_type, - self.c_type) + function_cast(justification_function, self.a_type, self.b_type, self.c_type, Score)), + self.a_type, self.b_type, self.c_type) def as_constraint(self, constraint_package_or_name: str, constraint_name: str = None) -> '_JavaConstraint': if constraint_name is None: @@ -123,10 +127,11 @@ def indict_with(self, indictment_function: Callable[[A, B, C, D], Collection]) - function_cast(indictment_function, self.a_type, self.b_type, self.c_type, self.d_type)), self.a_type, self.b_type, self.c_type, self.d_type) - def justify_with(self, justification_function: Callable[[A, B, C, D, ScoreType], Any]) -> \ - 'QuadConstraintBuilder[A, B, C, D, ScoreType]': + def justify_with(self, justification_function: Callable[[A, B, C, D, ScoreType], 'api.ConstraintJustification']) \ + -> 'QuadConstraintBuilder[A, B, C, D, ScoreType]': + from ai.timefold.solver.core.api.score import Score return QuadConstraintBuilder(self.delegate.justifyWith( - function_cast(justification_function, self.a_type, self.b_type, self.c_type, self.d_type)), + function_cast(justification_function, self.a_type, self.b_type, self.c_type, self.d_type, Score)), self.a_type, self.b_type, self.c_type, self.d_type) def as_constraint(self, constraint_package_or_name: str, constraint_name: str = None) -> '_JavaConstraint':