diff --git a/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/AnnotationMetadata.java b/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/AnnotationMetadata.java index 240d7ff2..f3593da0 100644 --- a/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/AnnotationMetadata.java +++ b/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/AnnotationMetadata.java @@ -1,7 +1,11 @@ package ai.timefold.jpyinterpreter; import java.lang.annotation.Annotation; +import java.lang.annotation.Repeatable; import java.lang.reflect.Array; +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; import java.util.Map; import org.objectweb.asm.AnnotationVisitor; @@ -23,6 +27,30 @@ public void addAnnotationTo(MethodVisitor methodVisitor) { visitAnnotation(methodVisitor.visitAnnotation(Type.getDescriptor(annotationType), true)); } + public static List getAnnotationListWithoutRepeatable(List metadata) { + List out = new ArrayList<>(); + Map, List> repeatableAnnotationMap = new LinkedHashMap<>(); + for (AnnotationMetadata annotation : metadata) { + Repeatable repeatable = annotation.annotationType().getAnnotation(Repeatable.class); + if (repeatable == null) { + out.add(annotation); + continue; + } + var annotationContainer = repeatable.value(); + repeatableAnnotationMap.computeIfAbsent(annotationContainer, + ignored -> new ArrayList<>()).add(annotation); + } + for (var entry : repeatableAnnotationMap.entrySet()) { + out.add(new AnnotationMetadata(entry.getKey(), + Map.of("value", entry.getValue().toArray(AnnotationMetadata[]::new)))); + } + return out; + } + + public static Type getValueAsType(String className) { + return Type.getType("L" + className.replace('.', '/') + ";"); + } + private void visitAnnotation(AnnotationVisitor annotationVisitor) { for (var entry : annotationValueMap.entrySet()) { var annotationAttributeName = entry.getKey(); @@ -42,8 +70,8 @@ private void visitAnnotationAttribute(AnnotationVisitor annotationVisitor, Strin return; } - if (attributeValue instanceof Class clazz) { - annotationVisitor.visit(attributeName, Type.getType(clazz)); + if (attributeValue instanceof Type type) { + annotationVisitor.visit(attributeName, type); return; } diff --git a/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/PythonClassTranslator.java b/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/PythonClassTranslator.java index 0984d657..dadaa4dc 100644 --- a/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/PythonClassTranslator.java +++ b/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/PythonClassTranslator.java @@ -46,6 +46,7 @@ import ai.timefold.jpyinterpreter.types.collections.PythonLikeTuple; import ai.timefold.jpyinterpreter.types.wrappers.JavaObjectWrapper; import ai.timefold.jpyinterpreter.types.wrappers.OpaquePythonReference; +import ai.timefold.jpyinterpreter.util.JavaPythonClassWriter; import ai.timefold.jpyinterpreter.util.arguments.ArgumentSpec; import org.objectweb.asm.ClassWriter; @@ -122,7 +123,8 @@ public static PythonLikeType translatePythonClass(PythonCompiledClass pythonComp for (Class javaInterface : pythonCompiledClass.javaInterfaces) { javaInterfaceImplementorSet.add( - new DelegatingInterfaceImplementor(internalClassName, javaInterface, instanceMethodNameToMethodDescriptor)); + new DelegatingInterfaceImplementor(internalClassName, javaInterface, + instanceMethodNameToMethodDescriptor)); } if (pythonCompiledClass.superclassList.isEmpty()) { @@ -173,14 +175,14 @@ public static PythonLikeType translatePythonClass(PythonCompiledClass pythonComp interfaces[i] = Type.getInternalName(nonObjectInterfaceImplementors.get(i).getInterfaceClass()); } - ClassWriter classWriter = new ClassWriter(ClassWriter.COMPUTE_MAXS | ClassWriter.COMPUTE_FRAMES); + ClassWriter classWriter = new JavaPythonClassWriter(ClassWriter.COMPUTE_MAXS | ClassWriter.COMPUTE_FRAMES); classWriter.visit(Opcodes.V11, Modifier.PUBLIC, internalClassName, null, superClassType.getJavaTypeInternalName(), interfaces); classWriter.visitSource(pythonCompiledClass.moduleFilePath, null); - for (var annotation : pythonCompiledClass.annotations) { + for (var annotation : AnnotationMetadata.getAnnotationListWithoutRepeatable(pythonCompiledClass.annotations)) { annotation.addAnnotationTo(classWriter); } @@ -208,6 +210,7 @@ public static PythonLikeType translatePythonClass(PythonCompiledClass pythonComp var typeHint = pythonCompiledClass.typeAnnotations.getOrDefault(attributeName, TypeHint.withoutAnnotations(BuiltinTypes.BASE_TYPE)); PythonLikeType type = typeHint.type(); + PythonLikeType javaGetterType = typeHint.javaGetterType(); if (type == null) { // null might be in __annotations__ type = BuiltinTypes.BASE_TYPE; } @@ -215,12 +218,15 @@ public static PythonLikeType translatePythonClass(PythonCompiledClass pythonComp attributeNameToTypeMap.put(attributeName, type); FieldVisitor fieldVisitor; String javaFieldTypeDescriptor; + String getterTypeDescriptor; String signature = null; boolean isJavaType; if (type.getJavaTypeInternalName().equals(Type.getInternalName(JavaObjectWrapper.class))) { javaFieldTypeDescriptor = Type.getDescriptor(type.getJavaObjectWrapperType()); - fieldVisitor = classWriter.visitField(Modifier.PUBLIC, getJavaFieldName(attributeName), javaFieldTypeDescriptor, - null, null); + getterTypeDescriptor = javaFieldTypeDescriptor; + fieldVisitor = + classWriter.visitField(Modifier.PUBLIC, getJavaFieldName(attributeName), javaFieldTypeDescriptor, + null, null); isJavaType = true; } else { if (typeHint.genericArgs() != null) { @@ -229,16 +235,21 @@ public static PythonLikeType translatePythonClass(PythonCompiledClass pythonComp signature = signatureWriter.toString(); } javaFieldTypeDescriptor = 'L' + type.getJavaTypeInternalName() + ';'; - fieldVisitor = classWriter.visitField(Modifier.PUBLIC, getJavaFieldName(attributeName), javaFieldTypeDescriptor, - signature, null); + getterTypeDescriptor = javaGetterType.getJavaTypeDescriptor(); + fieldVisitor = + classWriter.visitField(Modifier.PUBLIC, getJavaFieldName(attributeName), javaFieldTypeDescriptor, + signature, null); isJavaType = false; } fieldVisitor.visitEnd(); + createJavaGetterSetter(classWriter, preparedClassInfo, attributeName, Type.getType(javaFieldTypeDescriptor), + Type.getType(getterTypeDescriptor), signature, typeHint); + FieldDescriptor fieldDescriptor = new FieldDescriptor(attributeName, getJavaFieldName(attributeName), internalClassName, javaFieldTypeDescriptor, type, true, isJavaType); @@ -344,7 +355,8 @@ public static PythonLikeType translatePythonClass(PythonCompiledClass pythonComp pythonLikeType.$setAttribute("__module__", PythonString.valueOf(pythonCompiledClass.module)); PythonLikeDict annotations = new PythonLikeDict(); - pythonCompiledClass.typeAnnotations.forEach((name, type) -> annotations.put(PythonString.valueOf(name), type.type())); + pythonCompiledClass.typeAnnotations + .forEach((name, type) -> annotations.put(PythonString.valueOf(name), type.type())); pythonLikeType.$setAttribute("__annotations__", annotations); PythonLikeTuple mro = new PythonLikeTuple(); @@ -552,7 +564,7 @@ private static Class createPythonWrapperMethod(String methodName, PythonCompi String className = maybeClassName; String internalClassName = className.replace('.', '/'); - ClassWriter classWriter = new ClassWriter(ClassWriter.COMPUTE_MAXS | ClassWriter.COMPUTE_FRAMES); + ClassWriter classWriter = new JavaPythonClassWriter(ClassWriter.COMPUTE_MAXS | ClassWriter.COMPUTE_FRAMES); classWriter.visit(Opcodes.V11, Modifier.PUBLIC, internalClassName, null, Type.getInternalName(Object.class), new String[] { interfaceDeclaration.interfaceName }); @@ -663,7 +675,7 @@ private static PythonLikeFunction createConstructor(String classInternalName, String constructorClassName = maybeClassName; String constructorInternalClassName = constructorClassName.replace('.', '/'); - ClassWriter classWriter = new ClassWriter(ClassWriter.COMPUTE_MAXS | ClassWriter.COMPUTE_FRAMES); + ClassWriter classWriter = new JavaPythonClassWriter(ClassWriter.COMPUTE_MAXS | ClassWriter.COMPUTE_FRAMES); classWriter.visit(Opcodes.V11, Modifier.PUBLIC, constructorInternalClassName, null, Type.getInternalName(Object.class), new String[] { Type.getInternalName(PythonLikeFunction.class) @@ -775,24 +787,24 @@ private static PythonLikeFunction createConstructor(String classInternalName, private static void createJavaGetterSetter(ClassWriter classWriter, PreparedClassInfo preparedClassInfo, - String attributeName, Type attributeType, + String attributeName, Type attributeType, Type getterType, String signature, TypeHint typeHint) { - createJavaGetter(classWriter, preparedClassInfo, attributeName, attributeType, signature, typeHint); - createJavaSetter(classWriter, preparedClassInfo, attributeName, attributeType, signature, typeHint); + createJavaGetter(classWriter, preparedClassInfo, attributeName, attributeType, getterType, signature, typeHint); + createJavaSetter(classWriter, preparedClassInfo, attributeName, attributeType, getterType, signature, typeHint); } private static void createJavaGetter(ClassWriter classWriter, PreparedClassInfo preparedClassInfo, String attributeName, - Type attributeType, String signature, TypeHint typeHint) { + Type attributeType, Type getterType, String signature, TypeHint typeHint) { var getterName = "get" + attributeName.substring(0, 1).toUpperCase() + attributeName.substring(1); - if (signature != null) { + if (signature != null && Objects.equals(attributeType, getterType)) { signature = "()" + signature; } - var getterVisitor = classWriter.visitMethod(Modifier.PUBLIC, getterName, Type.getMethodDescriptor(attributeType), + var getterVisitor = classWriter.visitMethod(Modifier.PUBLIC, getterName, Type.getMethodDescriptor(getterType), signature, null); var maxStack = 1; - for (var annotation : typeHint.annotationList()) { + for (var annotation : AnnotationMetadata.getAnnotationListWithoutRepeatable(typeHint.annotationList())) { annotation.addAnnotationTo(getterVisitor); } @@ -813,19 +825,22 @@ private static void createJavaGetter(ClassWriter classWriter, PreparedClassInfo // If branch is taken, stack is field // If branch is not taken, stack is null } + if (!Objects.equals(attributeType, getterType)) { + getterVisitor.visitTypeInsn(Opcodes.CHECKCAST, getterType.getInternalName()); + } getterVisitor.visitInsn(Opcodes.ARETURN); getterVisitor.visitMaxs(maxStack, 0); getterVisitor.visitEnd(); } private static void createJavaSetter(ClassWriter classWriter, PreparedClassInfo preparedClassInfo, String attributeName, - Type attributeType, String signature, TypeHint typeHint) { + Type attributeType, Type setterType, String signature, TypeHint typeHint) { var setterName = "set" + attributeName.substring(0, 1).toUpperCase() + attributeName.substring(1); - if (signature != null) { + if (signature != null && Objects.equals(attributeType, setterType)) { signature = "(" + signature + ")V"; } var setterVisitor = classWriter.visitMethod(Modifier.PUBLIC, setterName, Type.getMethodDescriptor(Type.VOID_TYPE, - attributeType), + setterType), signature, null); var maxStack = 2; setterVisitor.visitCode(); @@ -845,6 +860,9 @@ private static void createJavaSetter(ClassWriter classWriter, PreparedClassInfo // If branch is taken, stack is (non-null instance) // If branch is not taken, stack is None } + if (!Objects.equals(attributeType, setterType)) { + setterVisitor.visitTypeInsn(Opcodes.CHECKCAST, attributeType.getInternalName()); + } setterVisitor.visitFieldInsn(Opcodes.PUTFIELD, preparedClassInfo.classInternalName, attributeName, attributeType.getDescriptor()); setterVisitor.visitInsn(Opcodes.RETURN); @@ -855,7 +873,7 @@ private static void createJavaSetter(ClassWriter classWriter, PreparedClassInfo private static void addAnnotationsToMethod(PythonCompiledFunction function, MethodVisitor methodVisitor) { var returnTypeHint = function.typeAnnotations.get("return"); if (returnTypeHint != null) { - for (var annotation : returnTypeHint.annotationList()) { + for (var annotation : AnnotationMetadata.getAnnotationListWithoutRepeatable(returnTypeHint.annotationList())) { annotation.addAnnotationTo(methodVisitor); } } @@ -1444,7 +1462,7 @@ public static InterfaceDeclaration createInterfaceForFunctionSignature(FunctionS String internalClassName = className.replace('.', '/'); - ClassWriter classWriter = new ClassWriter(ClassWriter.COMPUTE_MAXS | ClassWriter.COMPUTE_FRAMES); + ClassWriter classWriter = new JavaPythonClassWriter(ClassWriter.COMPUTE_MAXS | ClassWriter.COMPUTE_FRAMES); classWriter.visit(Opcodes.V11, Modifier.PUBLIC | Modifier.INTERFACE | Modifier.ABSTRACT, internalClassName, null, Type.getInternalName(Object.class), null); diff --git a/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/PythonCompiledFunction.java b/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/PythonCompiledFunction.java index 35582847..f06a6bb3 100644 --- a/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/PythonCompiledFunction.java +++ b/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/PythonCompiledFunction.java @@ -208,12 +208,16 @@ private static Class getParameterJavaClass(List parameter return (Class) parameterTypeList.get(variableIndex).getJavaClassOrDefault(PythonLikeObject.class); } + private static String getParameterJavaClassName(List parameterTypeList, int variableIndex) { + return parameterTypeList.get(variableIndex).getJavaTypeInternalName(); + } + @SuppressWarnings({ "unchecked", "rawtypes" }) public BiFunction> getArgumentSpecMapper() { return (defaultPositionalArguments, defaultKeywordArguments) -> { ArgumentSpec out = ArgumentSpec.forFunctionReturning(qualifiedName, getReturnType() - .map(type -> (Class) type.getJavaClassOrDefault(PythonLikeObject.class)) - .orElse(PythonLikeObject.class)); + .map(PythonLikeType::getJavaTypeInternalName) + .orElse(PythonLikeObject.class.getName())); int variableIndex = 0; int defaultPositionalStartIndex = co_argcount - defaultPositionalArguments.size(); @@ -226,23 +230,23 @@ public BiFunction= defaultPositionalStartIndex) { out = out.addPositionalOnlyArgument(co_varnames.get(variableIndex), - getParameterJavaClass(parameterTypeList, variableIndex), + getParameterJavaClassName(parameterTypeList, variableIndex), defaultPositionalArguments.get( variableIndex - defaultPositionalStartIndex)); } else { out = out.addPositionalOnlyArgument(co_varnames.get(variableIndex), - getParameterJavaClass(parameterTypeList, variableIndex)); + getParameterJavaClassName(parameterTypeList, variableIndex)); } } for (; variableIndex < co_argcount; variableIndex++) { if (variableIndex >= defaultPositionalStartIndex) { out = out.addArgument(co_varnames.get(variableIndex), - getParameterJavaClass(parameterTypeList, variableIndex), + getParameterJavaClassName(parameterTypeList, variableIndex), defaultPositionalArguments.get(variableIndex - defaultPositionalStartIndex)); } else { out = out.addArgument(co_varnames.get(variableIndex), - getParameterJavaClass(parameterTypeList, variableIndex)); + getParameterJavaClassName(parameterTypeList, variableIndex)); } } @@ -251,11 +255,11 @@ public BiFunction createDefaultArgumentFor(MethodDescriptor methodDescriptor, + public static String createDefaultArgumentFor(MethodDescriptor methodDescriptor, List defaultArgumentList, Map argumentNameToIndexMap, Optional extraPositionalArgumentsVariableIndex, Optional extraKeywordArgumentsVariableIndex, @@ -169,26 +169,40 @@ public static Class createDefaultArgumentFor(MethodDescriptor methodDescripto createAddArgumentMethod(classWriter, internalClassName, methodDescriptor, argumentNameToIndexMap, extraPositionalArgumentsVariableIndex, extraKeywordArgumentsVariableIndex, argumentSpec); + // clinit to set ArgumentSpec, as class cannot be loaded if it contains + // yet to be compiled forward references + methodVisitor = classWriter.visitMethod(Modifier.PUBLIC | Modifier.STATIC, "", + Type.getMethodDescriptor(Type.VOID_TYPE), null, null); + methodVisitor.visitCode(); + + argumentSpec.loadArgumentSpec(methodVisitor); + methodVisitor.visitInsn(Opcodes.DUP); + methodVisitor.visitFieldInsn(Opcodes.PUTSTATIC, internalClassName, + ARGUMENT_SPEC_STATIC_FIELD_NAME, Type.getDescriptor(ArgumentSpec.class)); + + for (int i = 0; i < defaultArgumentList.size(); i++) { + methodVisitor.visitInsn(Opcodes.DUP); + methodVisitor.visitLdcInsn(defaultStart + i); + methodVisitor.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(ArgumentSpec.class), + "getDefaultValue", Type.getMethodDescriptor(Type.getType(Object.class), Type.INT_TYPE), + false); + methodVisitor.visitTypeInsn(Opcodes.CHECKCAST, + methodDescriptor.getParameterTypes()[defaultStart + i].getInternalName()); + String fieldName = getConstantName(i); + methodVisitor.visitFieldInsn(Opcodes.PUTSTATIC, internalClassName, + fieldName, methodDescriptor.getParameterTypes()[defaultStart + i].getDescriptor()); + } + + methodVisitor.visitInsn(Opcodes.RETURN); + + methodVisitor.visitMaxs(-1, -1); + methodVisitor.visitEnd(); + classWriter.visitEnd(); PythonBytecodeToJavaBytecodeTranslator.writeClassOutput(BuiltinTypes.classNameToBytecode, className, classWriter.toByteArray()); - try { - Class compiledClass = BuiltinTypes.asmClassLoader.loadClass(className); - compiledClass.getField(ARGUMENT_SPEC_STATIC_FIELD_NAME).set(null, argumentSpec); - for (int i = 0; i < defaultArgumentList.size(); i++) { - PythonLikeObject value = defaultArgumentList.get(i); - String fieldName = getConstantName(i); - compiledClass.getField(fieldName).set(null, value); - } - return compiledClass; - } catch (ClassNotFoundException e) { - throw new IllegalStateException("Impossible State: Unable to load generated class (" + - className + ") despite it being just generated.", e); - } catch (NoSuchFieldException | IllegalAccessException e) { - throw new IllegalStateException("Impossible State: Unable to set field in generated class (" + - className + ").", e); - } + return internalClassName; } /** diff --git a/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/PythonFunctionSignature.java b/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/PythonFunctionSignature.java index ef6a056e..8fab72e0 100644 --- a/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/PythonFunctionSignature.java +++ b/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/PythonFunctionSignature.java @@ -30,7 +30,7 @@ public class PythonFunctionSignature { private final Optional extraPositionalArgumentsVariableIndex; private final Optional extraKeywordArgumentsVariableIndex; - private final Class defaultArgumentHolderClass; + private final String defaultArgumentHolderClassInternalName; private final ArgumentSpec argumentSpec; private final boolean isFromArgumentSpec; @@ -78,7 +78,7 @@ public PythonFunctionSignature(MethodDescriptor methodDescriptor, this.extraKeywordArgumentsVariableIndex = Optional.empty(); isFromArgumentSpec = false; argumentSpec = computeArgumentSpec(); - defaultArgumentHolderClass = PythonDefaultArgumentImplementor.createDefaultArgumentFor(methodDescriptor, + defaultArgumentHolderClassInternalName = PythonDefaultArgumentImplementor.createDefaultArgumentFor(methodDescriptor, defaultArgumentList, keywordToArgumentIndexMap, getExtraPositionalArgumentsVariableIndex(), getExtraKeywordArgumentsVariableIndex(), getArgumentSpec()); @@ -97,7 +97,7 @@ public PythonFunctionSignature(MethodDescriptor methodDescriptor, this.extraKeywordArgumentsVariableIndex = Optional.empty(); isFromArgumentSpec = false; argumentSpec = computeArgumentSpec(); - defaultArgumentHolderClass = PythonDefaultArgumentImplementor.createDefaultArgumentFor(methodDescriptor, + defaultArgumentHolderClassInternalName = PythonDefaultArgumentImplementor.createDefaultArgumentFor(methodDescriptor, defaultArgumentList, keywordToArgumentIndexMap, getExtraPositionalArgumentsVariableIndex(), getExtraKeywordArgumentsVariableIndex(), getArgumentSpec()); } @@ -117,7 +117,7 @@ public PythonFunctionSignature(MethodDescriptor methodDescriptor, this.extraKeywordArgumentsVariableIndex = extraKeywordArgumentsVariableIndex; isFromArgumentSpec = false; argumentSpec = computeArgumentSpec(); - defaultArgumentHolderClass = PythonDefaultArgumentImplementor.createDefaultArgumentFor(methodDescriptor, + defaultArgumentHolderClassInternalName = PythonDefaultArgumentImplementor.createDefaultArgumentFor(methodDescriptor, defaultArgumentList, keywordToArgumentIndexMap, extraPositionalArgumentsVariableIndex, extraKeywordArgumentsVariableIndex, getArgumentSpec()); } @@ -138,81 +138,77 @@ public PythonFunctionSignature(MethodDescriptor methodDescriptor, this.extraKeywordArgumentsVariableIndex = extraKeywordArgumentsVariableIndex; this.argumentSpec = argumentSpec; isFromArgumentSpec = true; - defaultArgumentHolderClass = PythonDefaultArgumentImplementor.createDefaultArgumentFor(methodDescriptor, + defaultArgumentHolderClassInternalName = PythonDefaultArgumentImplementor.createDefaultArgumentFor(methodDescriptor, defaultArgumentList, keywordToArgumentIndexMap, extraPositionalArgumentsVariableIndex, extraKeywordArgumentsVariableIndex, argumentSpec); } private ArgumentSpec computeArgumentSpec() { - try { - ArgumentSpec argumentSpec = ArgumentSpec.forFunctionReturning(getMethodDescriptor().getMethodName(), - (Class) getReturnType().getJavaClass()); - for (int i = 0; i < getParameterTypes().length - getDefaultArgumentList().size(); i++) { - if (getExtraPositionalArgumentsVariableIndex().isPresent() - && getExtraPositionalArgumentsVariableIndex().get() == i) { - continue; - } - - if (getExtraKeywordArgumentsVariableIndex().isPresent() && getExtraKeywordArgumentsVariableIndex().get() == i) { - continue; - } - - final int argIndex = i; - Optional argumentName = getKeywordToArgumentIndexMap().entrySet() - .stream().filter(e -> e.getValue().equals(argIndex)) - .map(Map.Entry::getKey) - .findAny(); - - if (argumentName.isEmpty()) { - argumentSpec = argumentSpec.addArgument("$arg" + i, - (Class) getParameterTypes()[i].getJavaClass()); - } else { - argumentSpec = argumentSpec.addArgument(argumentName.get(), - (Class) getParameterTypes()[i].getJavaClass()); - } + ArgumentSpec argumentSpec = ArgumentSpec.forFunctionReturning(getMethodDescriptor().getMethodName(), + getReturnType().getJavaTypeInternalName()); + for (int i = 0; i < getParameterTypes().length - getDefaultArgumentList().size(); i++) { + if (getExtraPositionalArgumentsVariableIndex().isPresent() + && getExtraPositionalArgumentsVariableIndex().get() == i) { + continue; + } + + if (getExtraKeywordArgumentsVariableIndex().isPresent() && getExtraKeywordArgumentsVariableIndex().get() == i) { + continue; } - for (int i = getParameterTypes().length - getDefaultArgumentList().size(); i < getParameterTypes().length; i++) { - if (getExtraPositionalArgumentsVariableIndex().isPresent() - && getExtraPositionalArgumentsVariableIndex().get() == i) { - continue; - } - - if (getExtraKeywordArgumentsVariableIndex().isPresent() && getExtraKeywordArgumentsVariableIndex().get() == i) { - continue; - } - - PythonLikeObject defaultValue = - getDefaultArgumentList().get(getDefaultArgumentList().size() - (getParameterTypes().length - i)); - - final int argIndex = i; - Optional argumentName = getKeywordToArgumentIndexMap().entrySet() - .stream().filter(e -> e.getValue().equals(argIndex)) - .map(Map.Entry::getKey) - .findAny(); - - if (argumentName.isEmpty()) { - argumentSpec = argumentSpec.addArgument("$arg" + i, - (Class) getParameterTypes()[i].getJavaClass(), - defaultValue); - } else { - argumentSpec = argumentSpec.addArgument(argumentName.get(), - (Class) getParameterTypes()[i].getJavaClass(), - defaultValue); - } + final int argIndex = i; + Optional argumentName = getKeywordToArgumentIndexMap().entrySet() + .stream().filter(e -> e.getValue().equals(argIndex)) + .map(Map.Entry::getKey) + .findAny(); + + if (argumentName.isEmpty()) { + argumentSpec = argumentSpec.addArgument("$arg" + i, + getParameterTypes()[i].getJavaTypeInternalName()); + } else { + argumentSpec = argumentSpec.addArgument(argumentName.get(), + getParameterTypes()[i].getJavaTypeInternalName()); + } + } + + for (int i = getParameterTypes().length - getDefaultArgumentList().size(); i < getParameterTypes().length; i++) { + if (getExtraPositionalArgumentsVariableIndex().isPresent() + && getExtraPositionalArgumentsVariableIndex().get() == i) { + continue; } - if (getExtraPositionalArgumentsVariableIndex().isPresent()) { - argumentSpec = argumentSpec.addExtraPositionalVarArgument("*vargs"); + if (getExtraKeywordArgumentsVariableIndex().isPresent() && getExtraKeywordArgumentsVariableIndex().get() == i) { + continue; } - if (getExtraKeywordArgumentsVariableIndex().isPresent()) { - argumentSpec = argumentSpec.addExtraKeywordVarArgument("**kwargs"); + PythonLikeObject defaultValue = + getDefaultArgumentList().get(getDefaultArgumentList().size() - (getParameterTypes().length - i)); + + final int argIndex = i; + Optional argumentName = getKeywordToArgumentIndexMap().entrySet() + .stream().filter(e -> e.getValue().equals(argIndex)) + .map(Map.Entry::getKey) + .findAny(); + + if (argumentName.isEmpty()) { + argumentSpec = argumentSpec.addArgument("$arg" + i, + getParameterTypes()[i].getJavaTypeInternalName(), + defaultValue); + } else { + argumentSpec = argumentSpec.addArgument(argumentName.get(), + getParameterTypes()[i].getJavaTypeInternalName(), + defaultValue); } - return argumentSpec; - } catch (ClassNotFoundException e) { - throw new RuntimeException(e); } + + if (getExtraPositionalArgumentsVariableIndex().isPresent()) { + argumentSpec = argumentSpec.addExtraPositionalVarArgument("*vargs"); + } + + if (getExtraKeywordArgumentsVariableIndex().isPresent()) { + argumentSpec = argumentSpec.addExtraKeywordVarArgument("**kwargs"); + } + return argumentSpec; } public ArgumentSpec getArgumentSpec() { @@ -251,8 +247,8 @@ public Optional getExtraKeywordArgumentsVariableIndex() { return extraKeywordArgumentsVariableIndex; } - public Class getDefaultArgumentHolderClass() { - return defaultArgumentHolderClass; + public String getDefaultArgumentHolderClassInternalName() { + return defaultArgumentHolderClassInternalName; } public boolean isVirtualMethod() { diff --git a/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/PythonOverloadImplementor.java b/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/PythonOverloadImplementor.java index 73b76cbb..297bcf1b 100644 --- a/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/PythonOverloadImplementor.java +++ b/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/PythonOverloadImplementor.java @@ -475,7 +475,7 @@ private static void createGenericDispatch(MethodVisitor methodVisitor, methodVisitor.visitVarInsn(Opcodes.ALOAD, 1); } methodVisitor.visitVarInsn(Opcodes.ALOAD, 2); - KnownCallImplementor.callUnpackListAndMap(functionSignature.getDefaultArgumentHolderClass(), + KnownCallImplementor.callUnpackListAndMap(functionSignature.getDefaultArgumentHolderClassInternalName(), functionSignature.getMethodDescriptor(), methodVisitor); methodVisitor.visitInsn(Opcodes.ARETURN); } diff --git a/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/TypeHint.java b/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/TypeHint.java index ec67d6a5..ba270ac6 100644 --- a/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/TypeHint.java +++ b/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/TypeHint.java @@ -6,20 +6,25 @@ import ai.timefold.jpyinterpreter.types.PythonLikeType; -public record TypeHint(PythonLikeType type, List annotationList, TypeHint[] genericArgs) { +public record TypeHint(PythonLikeType type, List annotationList, TypeHint[] genericArgs, + PythonLikeType javaGetterType) { public TypeHint { annotationList = Collections.unmodifiableList(annotationList); } public TypeHint(PythonLikeType type, List annotationList) { - this(type, annotationList, null); + this(type, annotationList, null, type); + } + + public TypeHint(PythonLikeType type, List annotationList, PythonLikeType javaGetterType) { + this(type, annotationList, null, javaGetterType); } public TypeHint addAnnotations(List addedAnnotations) { List combinedAnnotations = new ArrayList<>(annotationList.size() + addedAnnotations.size()); combinedAnnotations.addAll(annotationList); combinedAnnotations.addAll(addedAnnotations); - return new TypeHint(type, combinedAnnotations, genericArgs); + return new TypeHint(type, combinedAnnotations, genericArgs, javaGetterType); } public static TypeHint withoutAnnotations(PythonLikeType type) { diff --git a/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/implementors/KnownCallImplementor.java b/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/implementors/KnownCallImplementor.java index 338f8cb0..b2ad3179 100644 --- a/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/implementors/KnownCallImplementor.java +++ b/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/implementors/KnownCallImplementor.java @@ -113,14 +113,15 @@ public static void callMethod(PythonFunctionSignature pythonFunctionSignature, M // Now load and typecheck the local variables for (int i = 0; i < Math.min(specPositionalArgumentCount, argumentCount); i++) { localVariableHelper.readTemp(methodVisitor, Type.getType(PythonLikeObject.class), argumentLocals[i]); - methodVisitor.visitLdcInsn(Type.getType(pythonFunctionSignature.getArgumentSpec().getArgumentType(i))); + methodVisitor.visitLdcInsn( + Type.getType("L" + pythonFunctionSignature.getArgumentSpec().getArgumentTypeInternalName(i) + ";")); methodVisitor.visitMethodInsn(Opcodes.INVOKESTATIC, Type.getInternalName(JavaPythonTypeConversionImplementor.class), "coerceToType", Type.getMethodDescriptor(Type.getType(Object.class), Type.getType(PythonLikeObject.class), Type.getType(Class.class)), false); methodVisitor.visitTypeInsn(Opcodes.CHECKCAST, - Type.getInternalName(pythonFunctionSignature.getArgumentSpec().getArgumentType(i))); + pythonFunctionSignature.getArgumentSpec().getArgumentTypeInternalName(i)); } // Load any arguments missing values @@ -129,9 +130,9 @@ public static void callMethod(PythonFunctionSignature pythonFunctionSignature, M methodVisitor.visitInsn(Opcodes.ACONST_NULL); } else { methodVisitor.visitFieldInsn(Opcodes.GETSTATIC, - Type.getInternalName(pythonFunctionSignature.getDefaultArgumentHolderClass()), + pythonFunctionSignature.getDefaultArgumentHolderClassInternalName(), PythonDefaultArgumentImplementor.getConstantName(i), - Type.getDescriptor(pythonFunctionSignature.getArgumentSpec().getArgumentType(i))); + "L" + pythonFunctionSignature.getArgumentSpec().getArgumentTypeInternalName(i) + ";"); } } @@ -245,9 +246,9 @@ public static void callPython311andAbove(PythonFunctionSignature pythonFunctionS methodVisitor.visitInsn(Opcodes.ACONST_NULL); } else { methodVisitor.visitFieldInsn(Opcodes.GETSTATIC, - Type.getInternalName(pythonFunctionSignature.getDefaultArgumentHolderClass()), + pythonFunctionSignature.getDefaultArgumentHolderClassInternalName(), PythonDefaultArgumentImplementor.getConstantName(argumentIndex - defaultOffset), - Type.getDescriptor(pythonFunctionSignature.getArgumentSpec().getArgumentType(argumentIndex))); + "L" + pythonFunctionSignature.getArgumentSpec().getArgumentTypeInternalName(argumentIndex) + ";"); } localVariableHelper.writeTemp(methodVisitor, Type.getType(PythonLikeObject.class), argumentLocals[argumentIndex]); @@ -285,14 +286,15 @@ public static void callPython311andAbove(PythonFunctionSignature pythonFunctionS // Load arguments in proper order and typecast them for (int i = 0; i < specTotalArgumentCount; i++) { localVariableHelper.readTemp(methodVisitor, Type.getType(PythonLikeObject.class), argumentLocals[i]); - methodVisitor.visitLdcInsn(Type.getType(pythonFunctionSignature.getArgumentSpec().getArgumentType(i))); + methodVisitor.visitLdcInsn( + Type.getType("L" + pythonFunctionSignature.getArgumentSpec().getArgumentTypeInternalName(i) + ";")); methodVisitor.visitMethodInsn(Opcodes.INVOKESTATIC, Type.getInternalName(JavaPythonTypeConversionImplementor.class), "coerceToType", Type.getMethodDescriptor(Type.getType(Object.class), Type.getType(PythonLikeObject.class), Type.getType(Class.class)), false); methodVisitor.visitTypeInsn(Opcodes.CHECKCAST, - Type.getInternalName(pythonFunctionSignature.getArgumentSpec().getArgumentType(i))); + pythonFunctionSignature.getArgumentSpec().getArgumentTypeInternalName(i)); } pythonFunctionSignature.getMethodDescriptor().callMethod(methodVisitor); @@ -337,7 +339,7 @@ private static void callWithKeywords(PythonFunctionSignature pythonFunctionSigna Type[] descriptorParameterTypes = pythonFunctionSignature.getMethodDescriptor().getParameterTypes(); if (argumentCount < descriptorParameterTypes.length - && pythonFunctionSignature.getDefaultArgumentHolderClass() == null) { + && pythonFunctionSignature.getDefaultArgumentHolderClassInternalName() == null) { throw new IllegalStateException( "Cannot call " + pythonFunctionSignature + " because there are not enough arguments"); } @@ -355,7 +357,7 @@ private static void callWithKeywords(PythonFunctionSignature pythonFunctionSigna } // TOS is a tuple of keys - methodVisitor.visitTypeInsn(Opcodes.NEW, Type.getInternalName(pythonFunctionSignature.getDefaultArgumentHolderClass())); + methodVisitor.visitTypeInsn(Opcodes.NEW, pythonFunctionSignature.getDefaultArgumentHolderClassInternalName()); methodVisitor.visitInsn(Opcodes.DUP_X1); methodVisitor.visitInsn(Opcodes.SWAP); @@ -374,7 +376,7 @@ private static void callWithKeywords(PythonFunctionSignature pythonFunctionSigna // Stack is defaults (uninitialized), keys, positional arguments methodVisitor.visitMethodInsn(Opcodes.INVOKESPECIAL, - Type.getInternalName(pythonFunctionSignature.getDefaultArgumentHolderClass()), + pythonFunctionSignature.getDefaultArgumentHolderClassInternalName(), "", Type.getMethodDescriptor(Type.VOID_TYPE, Type.getType(PythonLikeTuple.class), Type.INT_TYPE), false); @@ -402,7 +404,7 @@ private static void callWithKeywords(PythonFunctionSignature pythonFunctionSigna methodVisitor.visitLabel(doneGettingType); } methodVisitor.visitMethodInsn(Opcodes.INVOKEVIRTUAL, - Type.getInternalName(pythonFunctionSignature.getDefaultArgumentHolderClass()), + pythonFunctionSignature.getDefaultArgumentHolderClassInternalName(), "addArgument", Type.getMethodDescriptor(Type.VOID_TYPE, Type.getType(PythonLikeObject.class)), false); } @@ -410,7 +412,7 @@ private static void callWithKeywords(PythonFunctionSignature pythonFunctionSigna for (int i = 0; i < descriptorParameterTypes.length; i++) { methodVisitor.visitInsn(Opcodes.DUP); methodVisitor.visitFieldInsn(Opcodes.GETFIELD, - Type.getInternalName(pythonFunctionSignature.getDefaultArgumentHolderClass()), + pythonFunctionSignature.getDefaultArgumentHolderClassInternalName(), PythonDefaultArgumentImplementor.getArgumentName(i), descriptorParameterTypes[i].getDescriptor()); methodVisitor.visitInsn(Opcodes.SWAP); @@ -420,7 +422,7 @@ private static void callWithKeywords(PythonFunctionSignature pythonFunctionSigna pythonFunctionSignature.getMethodDescriptor().callMethod(methodVisitor); } - public static void callUnpackListAndMap(Class defaultArgumentHolderClass, MethodDescriptor methodDescriptor, + public static void callUnpackListAndMap(String defaultArgumentHolderClassInternalName, MethodDescriptor methodDescriptor, MethodVisitor methodVisitor) { Type[] descriptorParameterTypes = methodDescriptor.getParameterTypes(); @@ -457,7 +459,7 @@ public static void callUnpackListAndMap(Class defaultArgumentHolderClass, Met // stack is bound-method, pos, keywords } - methodVisitor.visitFieldInsn(Opcodes.GETSTATIC, Type.getInternalName(defaultArgumentHolderClass), + methodVisitor.visitFieldInsn(Opcodes.GETSTATIC, defaultArgumentHolderClassInternalName, PythonDefaultArgumentImplementor.ARGUMENT_SPEC_STATIC_FIELD_NAME, Type.getDescriptor(ArgumentSpec.class)); diff --git a/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/opcodes/descriptor/MetaOpDescriptor.java b/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/opcodes/descriptor/MetaOpDescriptor.java index 96239f4a..9762327e 100644 --- a/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/opcodes/descriptor/MetaOpDescriptor.java +++ b/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/opcodes/descriptor/MetaOpDescriptor.java @@ -29,19 +29,24 @@ public enum MetaOpDescriptor implements OpcodeDescriptor { CALL_INTRINSIC_1(UnaryIntrinsicFunction::lookup), // TODO - EXTENDED_ARG(null), + EXTENDED_ARG(ignored -> { + throw new UnsupportedOperationException("EXTENDED_ARG"); + }), /** * Pushes builtins.__build_class__() onto the stack. * It is later called by CALL_FUNCTION to construct a class. */ - LOAD_BUILD_CLASS(null), + LOAD_BUILD_CLASS(ignored -> { + throw new UnsupportedOperationException("LOAD_BUILD_CLASS"); + }), /** * Checks whether __annotations__ is defined in locals(), if not it is set up to an empty dict. This opcode is only * emitted if a class or module body contains variable annotations statically. + * TODO: Properly implement this */ - SETUP_ANNOTATIONS(null); + SETUP_ANNOTATIONS(NopOpcode::new); private final VersionMapping versionLookup; diff --git a/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/opcodes/descriptor/VersionMapping.java b/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/opcodes/descriptor/VersionMapping.java index 294c0b5c..49dfed4a 100644 --- a/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/opcodes/descriptor/VersionMapping.java +++ b/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/opcodes/descriptor/VersionMapping.java @@ -1,6 +1,7 @@ package ai.timefold.jpyinterpreter.opcodes.descriptor; import java.util.NavigableMap; +import java.util.Objects; import java.util.TreeMap; import java.util.function.BiFunction; import java.util.function.Function; @@ -28,12 +29,12 @@ public static VersionMapping unimplemented() { public static VersionMapping constantMapping(Function mapper) { return new VersionMapping() - .map(PythonVersion.MINIMUM_PYTHON_VERSION, mapper); + .map(PythonVersion.MINIMUM_PYTHON_VERSION, Objects.requireNonNull(mapper)); } public static VersionMapping constantMapping(BiFunction mapper) { return new VersionMapping() - .map(PythonVersion.MINIMUM_PYTHON_VERSION, mapper); + .map(PythonVersion.MINIMUM_PYTHON_VERSION, Objects.requireNonNull(mapper)); } public VersionMapping map(PythonVersion version, Function mapper) { diff --git a/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/types/PythonString.java b/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/types/PythonString.java index c454eb73..5cf5a07c 100644 --- a/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/types/PythonString.java +++ b/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/types/PythonString.java @@ -135,7 +135,7 @@ private static PythonLikeType registerMethods() throws NoSuchMethodException { BuiltinTypes.STRING_TYPE.addMethod("find", PythonString.class.getMethod("findSubstringIndex", PythonString.class, PythonInteger.class, PythonInteger.class)); - BuiltinTypes.STRING_TYPE.addMethod("format", ArgumentSpec.forFunctionReturning("format", PythonString.class) + BuiltinTypes.STRING_TYPE.addMethod("format", ArgumentSpec.forFunctionReturning("format", PythonString.class.getName()) .addExtraPositionalVarArgument("vargs") .addExtraKeywordVarArgument("kwargs") .asPythonFunctionSignature(PythonString.class.getMethod("format", List.class, Map.class))); diff --git a/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/types/datetime/PythonDate.java b/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/types/datetime/PythonDate.java index a7648207..12ae08bd 100644 --- a/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/types/datetime/PythonDate.java +++ b/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/types/datetime/PythonDate.java @@ -62,10 +62,10 @@ public class PythonDate> extends AbstractPythonLikeObjec private static void registerMethods() throws NoSuchMethodException { // Constructor - DATE_TYPE.addConstructor(ArgumentSpec.forFunctionReturning("date", PythonDate.class) - .addArgument("year", PythonInteger.class) - .addArgument("month", PythonInteger.class) - .addArgument("day", PythonInteger.class) + DATE_TYPE.addConstructor(ArgumentSpec.forFunctionReturning("date", PythonDate.class.getName()) + .addArgument("year", PythonInteger.class.getName()) + .addArgument("month", PythonInteger.class.getName()) + .addArgument("day", PythonInteger.class.getName()) .asStaticPythonFunctionSignature(PythonDate.class.getMethod("of", PythonInteger.class, PythonInteger.class, PythonInteger.class))); // Unary Operators @@ -82,10 +82,10 @@ private static void registerMethods() throws NoSuchMethodException { // Methods DATE_TYPE.addMethod("replace", - ArgumentSpec.forFunctionReturning("replace", PythonDate.class) - .addNullableArgument("year", PythonInteger.class) - .addNullableArgument("month", PythonInteger.class) - .addNullableArgument("day", PythonInteger.class) + ArgumentSpec.forFunctionReturning("replace", PythonDate.class.getName()) + .addNullableArgument("year", PythonInteger.class.getName()) + .addNullableArgument("month", PythonInteger.class.getName()) + .addNullableArgument("day", PythonInteger.class.getName()) .asPythonFunctionSignature(PythonDate.class.getMethod("replace", PythonInteger.class, PythonInteger.class, PythonInteger.class))); DATE_TYPE.addMethod("timetuple", @@ -113,39 +113,39 @@ private static void registerMethods() throws NoSuchMethodException { // Class methods DATE_TYPE.addMethod("today", - ArgumentSpec.forFunctionReturning("today", PythonDate.class) - .addArgument("date_type", PythonLikeType.class) + ArgumentSpec.forFunctionReturning("today", PythonDate.class.getName()) + .addArgument("date_type", PythonLikeType.class.getName()) .asClassPythonFunctionSignature(PythonDate.class.getMethod("today", PythonLikeType.class))); DATE_TYPE.addMethod("fromtimestamp", - ArgumentSpec.forFunctionReturning("fromtimestamp", PythonDate.class) - .addArgument("date_type", PythonLikeType.class) - .addArgument("timestamp", PythonNumber.class) + ArgumentSpec.forFunctionReturning("fromtimestamp", PythonDate.class.getName()) + .addArgument("date_type", PythonLikeType.class.getName()) + .addArgument("timestamp", PythonNumber.class.getName()) .asClassPythonFunctionSignature(PythonDate.class.getMethod("from_timestamp", PythonLikeType.class, PythonNumber.class))); DATE_TYPE.addMethod("fromordinal", - ArgumentSpec.forFunctionReturning("fromordinal", PythonDate.class) - .addArgument("date_type", PythonLikeType.class) - .addArgument("ordinal", PythonInteger.class) + ArgumentSpec.forFunctionReturning("fromordinal", PythonDate.class.getName()) + .addArgument("date_type", PythonLikeType.class.getName()) + .addArgument("ordinal", PythonInteger.class.getName()) .asClassPythonFunctionSignature(PythonDate.class.getMethod("from_ordinal", PythonLikeType.class, PythonInteger.class))); DATE_TYPE.addMethod("fromisoformat", - ArgumentSpec.forFunctionReturning("fromisoformat", PythonDate.class) - .addArgument("date_type", PythonLikeType.class) - .addArgument("date_string", PythonString.class) + ArgumentSpec.forFunctionReturning("fromisoformat", PythonDate.class.getName()) + .addArgument("date_type", PythonLikeType.class.getName()) + .addArgument("date_string", PythonString.class.getName()) .asClassPythonFunctionSignature(PythonDate.class.getMethod("from_iso_format", PythonLikeType.class, PythonString.class))); DATE_TYPE.addMethod("fromisocalendar", - ArgumentSpec.forFunctionReturning("fromisocalendar", PythonDate.class) - .addArgument("date_type", PythonLikeType.class) - .addArgument("year", PythonInteger.class) - .addArgument("month", PythonInteger.class) - .addArgument("day", PythonInteger.class) + ArgumentSpec.forFunctionReturning("fromisocalendar", PythonDate.class.getName()) + .addArgument("date_type", PythonLikeType.class.getName()) + .addArgument("year", PythonInteger.class.getName()) + .addArgument("month", PythonInteger.class.getName()) + .addArgument("day", PythonInteger.class.getName()) .asClassPythonFunctionSignature(PythonDate.class.getMethod("from_iso_calendar", PythonLikeType.class, PythonInteger.class, PythonInteger.class, PythonInteger.class))); } diff --git a/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/types/datetime/PythonDateTime.java b/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/types/datetime/PythonDateTime.java index 076fa74c..b56dc1ba 100644 --- a/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/types/datetime/PythonDateTime.java +++ b/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/types/datetime/PythonDateTime.java @@ -77,16 +77,16 @@ public class PythonDateTime extends PythonDate implements Planni private static void registerMethods() throws NoSuchMethodException { // Constructor - DATE_TIME_TYPE.addConstructor(ArgumentSpec.forFunctionReturning("datetime", PythonDateTime.class) - .addArgument("year", PythonInteger.class) - .addArgument("month", PythonInteger.class) - .addArgument("day", PythonInteger.class) - .addArgument("hour", PythonInteger.class, PythonInteger.ZERO) - .addArgument("minute", PythonInteger.class, PythonInteger.ZERO) - .addArgument("second", PythonInteger.class, PythonInteger.ZERO) - .addArgument("microsecond", PythonInteger.class, PythonInteger.ZERO) - .addArgument("tzinfo", PythonLikeObject.class, PythonNone.INSTANCE) - .addKeywordOnlyArgument("fold", PythonInteger.class, PythonInteger.ZERO) + DATE_TIME_TYPE.addConstructor(ArgumentSpec.forFunctionReturning("datetime", PythonDateTime.class.getName()) + .addArgument("year", PythonInteger.class.getName()) + .addArgument("month", PythonInteger.class.getName()) + .addArgument("day", PythonInteger.class.getName()) + .addArgument("hour", PythonInteger.class.getName(), PythonInteger.ZERO) + .addArgument("minute", PythonInteger.class.getName(), PythonInteger.ZERO) + .addArgument("second", PythonInteger.class.getName(), PythonInteger.ZERO) + .addArgument("microsecond", PythonInteger.class.getName(), PythonInteger.ZERO) + .addArgument("tzinfo", PythonLikeObject.class.getName(), PythonNone.INSTANCE) + .addKeywordOnlyArgument("fold", PythonInteger.class.getName(), PythonInteger.ZERO) .asPythonFunctionSignature( PythonDateTime.class.getMethod("of", PythonInteger.class, PythonInteger.class, PythonInteger.class, PythonInteger.class, PythonInteger.class, PythonInteger.class, PythonInteger.class, @@ -95,45 +95,45 @@ private static void registerMethods() throws NoSuchMethodException { // Class methods // Date handles today, DATE_TIME_TYPE.addMethod("now", - ArgumentSpec.forFunctionReturning("now", PythonDateTime.class) - .addArgument("datetime_type", PythonLikeType.class) - .addArgument("tzinfo", PythonLikeObject.class, PythonNone.INSTANCE) + ArgumentSpec.forFunctionReturning("now", PythonDateTime.class.getName()) + .addArgument("datetime_type", PythonLikeType.class.getName()) + .addArgument("tzinfo", PythonLikeObject.class.getName(), PythonNone.INSTANCE) .asClassPythonFunctionSignature( PythonDateTime.class.getMethod("now", PythonLikeType.class, PythonLikeObject.class))); DATE_TIME_TYPE.addMethod("utcnow", - ArgumentSpec.forFunctionReturning("now", PythonDateTime.class) - .addArgument("datetime_type", PythonLikeType.class) + ArgumentSpec.forFunctionReturning("now", PythonDateTime.class.getName()) + .addArgument("datetime_type", PythonLikeType.class.getName()) .asClassPythonFunctionSignature( PythonDateTime.class.getMethod("utc_now", PythonLikeType.class))); DATE_TIME_TYPE.addMethod("fromtimestamp", - ArgumentSpec.forFunctionReturning("fromtimestamp", PythonDate.class) - .addArgument("date_type", PythonLikeType.class) - .addArgument("timestamp", PythonNumber.class) - .addArgument("tzinfo", PythonLikeObject.class, PythonNone.INSTANCE) + ArgumentSpec.forFunctionReturning("fromtimestamp", PythonDate.class.getName()) + .addArgument("date_type", PythonLikeType.class.getName()) + .addArgument("timestamp", PythonNumber.class.getName()) + .addArgument("tzinfo", PythonLikeObject.class.getName(), PythonNone.INSTANCE) .asClassPythonFunctionSignature(PythonDateTime.class.getMethod("from_timestamp", PythonLikeType.class, PythonNumber.class, PythonLikeObject.class))); DATE_TIME_TYPE.addMethod("utcfromtimestamp", - ArgumentSpec.forFunctionReturning("utcfromtimestamp", PythonDate.class) - .addArgument("date_type", PythonLikeType.class) - .addArgument("timestamp", PythonNumber.class) + ArgumentSpec.forFunctionReturning("utcfromtimestamp", PythonDate.class.getName()) + .addArgument("date_type", PythonLikeType.class.getName()) + .addArgument("timestamp", PythonNumber.class.getName()) .asClassPythonFunctionSignature(PythonDateTime.class.getMethod("utc_from_timestamp", PythonLikeType.class, PythonNumber.class))); DATE_TIME_TYPE.addMethod("combine", - ArgumentSpec.forFunctionReturning("combine", PythonDateTime.class) - .addArgument("datetime_type", PythonLikeType.class) - .addArgument("date", PythonDate.class) - .addArgument("time", PythonTime.class) - .addNullableArgument("tzinfo", PythonLikeObject.class) + ArgumentSpec.forFunctionReturning("combine", PythonDateTime.class.getName()) + .addArgument("datetime_type", PythonLikeType.class.getName()) + .addArgument("date", PythonDate.class.getName()) + .addArgument("time", PythonTime.class.getName()) + .addNullableArgument("tzinfo", PythonLikeObject.class.getName()) .asClassPythonFunctionSignature( PythonDateTime.class.getMethod("combine", PythonLikeType.class, PythonDate.class, @@ -153,16 +153,16 @@ private static void registerMethods() throws NoSuchMethodException { // Instance methods DATE_TIME_TYPE.addMethod("replace", - ArgumentSpec.forFunctionReturning("replace", PythonDate.class) - .addNullableArgument("year", PythonInteger.class) - .addNullableArgument("month", PythonInteger.class) - .addNullableArgument("day", PythonInteger.class) - .addNullableArgument("hour", PythonInteger.class) - .addNullableArgument("minute", PythonInteger.class) - .addNullableArgument("second", PythonInteger.class) - .addNullableArgument("microsecond", PythonInteger.class) - .addNullableArgument("tzinfo", PythonLikeObject.class) - .addNullableKeywordOnlyArgument("fold", PythonInteger.class) + ArgumentSpec.forFunctionReturning("replace", PythonDate.class.getName()) + .addNullableArgument("year", PythonInteger.class.getName()) + .addNullableArgument("month", PythonInteger.class.getName()) + .addNullableArgument("day", PythonInteger.class.getName()) + .addNullableArgument("hour", PythonInteger.class.getName()) + .addNullableArgument("minute", PythonInteger.class.getName()) + .addNullableArgument("second", PythonInteger.class.getName()) + .addNullableArgument("microsecond", PythonInteger.class.getName()) + .addNullableArgument("tzinfo", PythonLikeObject.class.getName()) + .addNullableKeywordOnlyArgument("fold", PythonInteger.class.getName()) .asPythonFunctionSignature(PythonDateTime.class.getMethod("replace", PythonInteger.class, PythonInteger.class, PythonInteger.class, PythonInteger.class, PythonInteger.class, PythonInteger.class, @@ -197,9 +197,9 @@ private static void registerMethods() throws NoSuchMethodException { PythonDateTime.class.getMethod("dst")); DATE_TIME_TYPE.addMethod("isoformat", - ArgumentSpec.forFunctionReturning("isoformat", PythonString.class) - .addArgument("sep", PythonString.class, PythonString.valueOf("T")) - .addArgument("timespec", PythonString.class, PythonString.valueOf("auto")) + ArgumentSpec.forFunctionReturning("isoformat", PythonString.class.getName()) + .addArgument("sep", PythonString.class.getName(), PythonString.valueOf("T")) + .addArgument("timespec", PythonString.class.getName(), PythonString.valueOf("auto")) .asPythonFunctionSignature( PythonDateTime.class.getMethod("iso_format", PythonString.class, PythonString.class))); diff --git a/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/types/datetime/PythonTime.java b/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/types/datetime/PythonTime.java index 69edacff..d6194970 100644 --- a/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/types/datetime/PythonTime.java +++ b/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/types/datetime/PythonTime.java @@ -55,39 +55,39 @@ public class PythonTime extends AbstractPythonLikeObject implements PlanningImmu } private static void registerMethods() throws NoSuchMethodException { - TIME_TYPE.addConstructor(ArgumentSpec.forFunctionReturning("datetime.time", PythonTime.class) - .addArgument("hour", PythonInteger.class, PythonInteger.ZERO) - .addArgument("minute", PythonInteger.class, PythonInteger.ZERO) - .addArgument("second", PythonInteger.class, PythonInteger.ZERO) - .addArgument("microsecond", PythonInteger.class, PythonInteger.ZERO) - .addArgument("tzinfo", PythonLikeObject.class, PythonNone.INSTANCE) - .addKeywordOnlyArgument("fold", PythonInteger.class, PythonInteger.ZERO) + TIME_TYPE.addConstructor(ArgumentSpec.forFunctionReturning("datetime.time", PythonTime.class.getName()) + .addArgument("hour", PythonInteger.class.getName(), PythonInteger.ZERO) + .addArgument("minute", PythonInteger.class.getName(), PythonInteger.ZERO) + .addArgument("second", PythonInteger.class.getName(), PythonInteger.ZERO) + .addArgument("microsecond", PythonInteger.class.getName(), PythonInteger.ZERO) + .addArgument("tzinfo", PythonLikeObject.class.getName(), PythonNone.INSTANCE) + .addKeywordOnlyArgument("fold", PythonInteger.class.getName(), PythonInteger.ZERO) .asPythonFunctionSignature( PythonTime.class.getMethod("of", PythonInteger.class, PythonInteger.class, PythonInteger.class, PythonInteger.class, PythonLikeObject.class, PythonInteger.class))); TIME_TYPE.addMethod("fromisoformat", - ArgumentSpec.forFunctionReturning("fromisoformat", PythonTime.class) - .addArgument("time_string", PythonString.class) + ArgumentSpec.forFunctionReturning("fromisoformat", PythonTime.class.getName()) + .addArgument("time_string", PythonString.class.getName()) .asStaticPythonFunctionSignature(PythonTime.class.getMethod("from_iso_format", PythonString.class))); TIME_TYPE.addMethod("replace", - ArgumentSpec.forFunctionReturning("replace", PythonTime.class) - .addNullableArgument("hour", PythonInteger.class) - .addNullableArgument("minute", PythonInteger.class) - .addNullableArgument("second", PythonInteger.class) - .addNullableArgument("microsecond", PythonInteger.class) - .addNullableArgument("tzinfo", PythonLikeObject.class) - .addNullableKeywordOnlyArgument("fold", PythonInteger.class) + ArgumentSpec.forFunctionReturning("replace", PythonTime.class.getName()) + .addNullableArgument("hour", PythonInteger.class.getName()) + .addNullableArgument("minute", PythonInteger.class.getName()) + .addNullableArgument("second", PythonInteger.class.getName()) + .addNullableArgument("microsecond", PythonInteger.class.getName()) + .addNullableArgument("tzinfo", PythonLikeObject.class.getName()) + .addNullableKeywordOnlyArgument("fold", PythonInteger.class.getName()) .asPythonFunctionSignature(PythonTime.class.getMethod("replace", PythonInteger.class, PythonInteger.class, PythonInteger.class, PythonInteger.class, PythonLikeObject.class, PythonInteger.class))); TIME_TYPE.addMethod("isoformat", - ArgumentSpec.forFunctionReturning("isoformat", PythonString.class) - .addArgument("timespec", PythonString.class, PythonString.valueOf("auto")) + ArgumentSpec.forFunctionReturning("isoformat", PythonString.class.getName()) + .addArgument("timespec", PythonString.class.getName(), PythonString.valueOf("auto")) .asPythonFunctionSignature(PythonTime.class.getMethod("isoformat", PythonString.class))); TIME_TYPE.addMethod("tzname", diff --git a/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/types/datetime/PythonTimeDelta.java b/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/types/datetime/PythonTimeDelta.java index dd836a19..97c0e288 100644 --- a/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/types/datetime/PythonTimeDelta.java +++ b/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/types/datetime/PythonTimeDelta.java @@ -53,14 +53,14 @@ public class PythonTimeDelta extends AbstractPythonLikeObject implements PythonL private static void registerMethods() throws NoSuchMethodException { // Constructor - TIME_DELTA_TYPE.addConstructor(ArgumentSpec.forFunctionReturning("timedelta", PythonTimeDelta.class) - .addArgument("days", PythonNumber.class, PythonInteger.ZERO) - .addArgument("seconds", PythonNumber.class, PythonInteger.ZERO) - .addArgument("microseconds", PythonNumber.class, PythonInteger.ZERO) - .addArgument("milliseconds", PythonNumber.class, PythonInteger.ZERO) - .addArgument("minutes", PythonNumber.class, PythonInteger.ZERO) - .addArgument("hours", PythonNumber.class, PythonInteger.ZERO) - .addArgument("weeks", PythonNumber.class, PythonInteger.ZERO) + TIME_DELTA_TYPE.addConstructor(ArgumentSpec.forFunctionReturning("timedelta", PythonTimeDelta.class.getName()) + .addArgument("days", PythonNumber.class.getName(), PythonInteger.ZERO) + .addArgument("seconds", PythonNumber.class.getName(), PythonInteger.ZERO) + .addArgument("microseconds", PythonNumber.class.getName(), PythonInteger.ZERO) + .addArgument("milliseconds", PythonNumber.class.getName(), PythonInteger.ZERO) + .addArgument("minutes", PythonNumber.class.getName(), PythonInteger.ZERO) + .addArgument("hours", PythonNumber.class.getName(), PythonInteger.ZERO) + .addArgument("weeks", PythonNumber.class.getName(), PythonInteger.ZERO) .asPythonFunctionSignature(PythonTimeDelta.class.getMethod("of", PythonNumber.class, PythonNumber.class, PythonNumber.class, PythonNumber.class, PythonNumber.class, PythonNumber.class, PythonNumber.class))); diff --git a/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/util/arguments/ArgumentSpec.java b/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/util/arguments/ArgumentSpec.java index 6695670c..b9c604a9 100644 --- a/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/util/arguments/ArgumentSpec.java +++ b/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/util/arguments/ArgumentSpec.java @@ -25,11 +25,17 @@ import ai.timefold.jpyinterpreter.types.collections.PythonLikeTuple; import ai.timefold.jpyinterpreter.types.errors.TypeError; +import org.objectweb.asm.MethodVisitor; +import org.objectweb.asm.Opcodes; +import org.objectweb.asm.Type; + public final class ArgumentSpec { - private final Class functionReturnType; + private static List> ARGUMENT_SPECS = new ArrayList<>(); + + private final String functionReturnTypeName; private final String functionName; private final List argumentNameList; - private final List> argumentTypeList; + private final List argumentTypeNameList; private final List argumentKindList; private final List argumentDefaultList; private final BitSet nullableArgumentSet; @@ -39,13 +45,16 @@ public final class ArgumentSpec { private final int numberOfPositionalArguments; private final int requiredPositionalArguments; - private ArgumentSpec(String functionName, Class functionReturnType) { - this.functionReturnType = functionReturnType; + private Class functionReturnType = null; + private List argumentTypeList = null; + + private ArgumentSpec(String functionName, String functionReturnTypeName) { + this.functionReturnTypeName = functionReturnTypeName; this.functionName = functionName + "()"; requiredPositionalArguments = 0; numberOfPositionalArguments = 0; argumentNameList = Collections.emptyList(); - argumentTypeList = Collections.emptyList(); + argumentTypeNameList = Collections.emptyList(); argumentKindList = Collections.emptyList(); argumentDefaultList = Collections.emptyList(); extraPositionalsArgumentIndex = Optional.empty(); @@ -53,11 +62,11 @@ private ArgumentSpec(String functionName, Class functionReturnType) { nullableArgumentSet = new BitSet(); } - private ArgumentSpec(String argumentName, Class argumentType, ArgumentKind argumentKind, Object defaultValue, + private ArgumentSpec(String argumentName, String argumentTypeName, ArgumentKind argumentKind, Object defaultValue, Optional extraPositionalsArgumentIndex, Optional extraKeywordsArgumentIndex, boolean allowNull, ArgumentSpec previousSpec) { functionName = previousSpec.functionName; - functionReturnType = previousSpec.functionReturnType; + functionReturnTypeName = previousSpec.functionReturnTypeName; if (previousSpec.numberOfPositionalArguments < previousSpec.getTotalArgumentCount()) { numberOfPositionalArguments = previousSpec.numberOfPositionalArguments; @@ -80,15 +89,15 @@ private ArgumentSpec(String argumentName, Class argumentType, ArgumentKind ar } argumentNameList = new ArrayList<>(previousSpec.argumentNameList.size() + 1); - argumentTypeList = new ArrayList<>(previousSpec.argumentTypeList.size() + 1); + argumentTypeNameList = new ArrayList<>(previousSpec.argumentTypeNameList.size() + 1); argumentKindList = new ArrayList<>(previousSpec.argumentKindList.size() + 1); argumentDefaultList = new ArrayList<>(previousSpec.argumentDefaultList.size() + 1); argumentNameList.addAll(previousSpec.argumentNameList); argumentNameList.add(argumentName); - argumentTypeList.addAll(previousSpec.argumentTypeList); - argumentTypeList.add(argumentType); + argumentTypeNameList.addAll(previousSpec.argumentTypeNameList); + argumentTypeNameList.add(argumentTypeName); argumentKindList.addAll(previousSpec.argumentKindList); argumentKindList.add(argumentKind); @@ -119,7 +128,7 @@ private ArgumentSpec(String argumentName, Class argumentType, ArgumentKind ar } public static ArgumentSpec forFunctionReturning(String functionName, - Class outClass) { + String outClass) { return new ArgumentSpec<>(functionName, outClass); } @@ -139,8 +148,8 @@ public boolean hasExtraKeywordArgumentsCapture() { return extraKeywordsArgumentIndex.isPresent(); } - public Class getArgumentType(int argumentIndex) { - return argumentTypeList.get(argumentIndex); + public String getArgumentTypeInternalName(int argumentIndex) { + return argumentTypeNameList.get(argumentIndex).replace('.', '/'); } public ArgumentKind getArgumentKind(int argumentIndex) { @@ -184,8 +193,43 @@ public Collection getUnspecifiedArgumentSet(int positionalArguments, Li .collect(Collectors.toList()); } + public static ArgumentSpec getArgumentSpec(int argumentSpecIndex) { + return ARGUMENT_SPECS.get(argumentSpecIndex); + } + + public void loadArgumentSpec(MethodVisitor methodVisitor) { + int index = ARGUMENT_SPECS.size(); + ARGUMENT_SPECS.add(this); + methodVisitor.visitLdcInsn(index); + methodVisitor.visitMethodInsn(Opcodes.INVOKESTATIC, Type.getInternalName(ArgumentSpec.class), + "getArgumentSpec", Type.getMethodDescriptor(Type.getType(ArgumentSpec.class), + Type.INT_TYPE), + false); + } + + private void computeArgumentTypeList() { + if (argumentTypeList == null) { + try { + functionReturnType = BuiltinTypes.asmClassLoader.loadClass(functionReturnTypeName.replace('/', '.')); + } catch (ClassNotFoundException e) { + throw new RuntimeException(e); + } + argumentTypeList = argumentTypeNameList.stream() + .map(className -> { + try { + return (Class) BuiltinTypes.asmClassLoader.loadClass(className.replace('/', '.')); + } catch (ClassNotFoundException e) { + throw new RuntimeException(e); + } + }) + .toList(); + } + } + public List extractArgumentList(List positionalArguments, Map keywordArguments) { + computeArgumentTypeList(); + List out = new ArrayList<>(argumentNameList.size()); if (positionalArguments.size() > numberOfPositionalArguments && @@ -305,6 +349,8 @@ public List extractArgumentList(List positio public boolean verifyMatchesCallSignature(int positionalArgumentCount, List keywordArgumentNameList, List callStackTypeList) { + computeArgumentTypeList(); + Set missingValue = getRequiredArgumentIndexSet(); for (int keywordIndex = 0; keywordIndex < keywordArgumentNameList.size(); keywordIndex++) { String keyword = keywordArgumentNameList.get(keywordIndex); @@ -373,73 +419,73 @@ private Set getRequiredArgumentIndexSet() { } private ArgumentSpec addArgument(String argumentName, - Class argumentType, ArgumentKind argumentKind, ArgumentType_ defaultValue, + String argumentTypeName, ArgumentKind argumentKind, ArgumentType_ defaultValue, Optional extraPositionalsArgumentIndex, Optional extraKeywordsArgumentIndex, boolean allowNull) { - return new ArgumentSpec<>(argumentName, argumentType, argumentKind, defaultValue, + return new ArgumentSpec<>(argumentName, argumentTypeName, argumentKind, defaultValue, extraPositionalsArgumentIndex, extraKeywordsArgumentIndex, allowNull, this); } public ArgumentSpec addArgument(String argumentName, - Class argumentType) { - return addArgument(argumentName, argumentType, ArgumentKind.POSITIONAL_AND_KEYWORD, null, + String argumentTypeName) { + return addArgument(argumentName, argumentTypeName, ArgumentKind.POSITIONAL_AND_KEYWORD, null, Optional.empty(), Optional.empty(), false); } public ArgumentSpec - addPositionalOnlyArgument(String argumentName, Class argumentType) { - return addArgument(argumentName, argumentType, ArgumentKind.POSITIONAL_ONLY, null, + addPositionalOnlyArgument(String argumentName, String argumentTypeName) { + return addArgument(argumentName, argumentTypeName, ArgumentKind.POSITIONAL_ONLY, null, Optional.empty(), Optional.empty(), false); } public ArgumentSpec - addKeywordOnlyArgument(String argumentName, Class argumentType) { - return addArgument(argumentName, argumentType, ArgumentKind.KEYWORD_ONLY, null, + addKeywordOnlyArgument(String argumentName, String argumentTypeName) { + return addArgument(argumentName, argumentTypeName, ArgumentKind.KEYWORD_ONLY, null, Optional.empty(), Optional.empty(), false); } public ArgumentSpec addArgument(String argumentName, - Class argumentType, ArgumentType_ defaultValue) { - return addArgument(argumentName, argumentType, ArgumentKind.POSITIONAL_AND_KEYWORD, defaultValue, + String argumentTypeName, ArgumentType_ defaultValue) { + return addArgument(argumentName, argumentTypeName, ArgumentKind.POSITIONAL_AND_KEYWORD, defaultValue, Optional.empty(), Optional.empty(), false); } public ArgumentSpec - addPositionalOnlyArgument(String argumentName, Class argumentType, ArgumentType_ defaultValue) { - return addArgument(argumentName, argumentType, ArgumentKind.POSITIONAL_ONLY, defaultValue, + addPositionalOnlyArgument(String argumentName, String argumentTypeName, ArgumentType_ defaultValue) { + return addArgument(argumentName, argumentTypeName, ArgumentKind.POSITIONAL_ONLY, defaultValue, Optional.empty(), Optional.empty(), false); } public ArgumentSpec - addKeywordOnlyArgument(String argumentName, Class argumentType, ArgumentType_ defaultValue) { - return addArgument(argumentName, argumentType, ArgumentKind.KEYWORD_ONLY, defaultValue, + addKeywordOnlyArgument(String argumentName, String argumentTypeName, ArgumentType_ defaultValue) { + return addArgument(argumentName, argumentTypeName, ArgumentKind.KEYWORD_ONLY, defaultValue, Optional.empty(), Optional.empty(), false); } public ArgumentSpec addNullableArgument(String argumentName, - Class argumentType) { - return addArgument(argumentName, argumentType, ArgumentKind.POSITIONAL_AND_KEYWORD, null, + String argumentTypeName) { + return addArgument(argumentName, argumentTypeName, ArgumentKind.POSITIONAL_AND_KEYWORD, null, Optional.empty(), Optional.empty(), true); } public ArgumentSpec addNullablePositionalOnlyArgument(String argumentName, - Class argumentType) { - return addArgument(argumentName, argumentType, ArgumentKind.POSITIONAL_ONLY, null, + String argumentTypeName) { + return addArgument(argumentName, argumentTypeName, ArgumentKind.POSITIONAL_ONLY, null, Optional.empty(), Optional.empty(), true); } public ArgumentSpec addNullableKeywordOnlyArgument(String argumentName, - Class argumentType) { - return addArgument(argumentName, argumentType, ArgumentKind.KEYWORD_ONLY, null, + String argumentTypeName) { + return addArgument(argumentName, argumentTypeName, ArgumentKind.KEYWORD_ONLY, null, Optional.empty(), Optional.empty(), true); } public ArgumentSpec addExtraPositionalVarArgument(String argumentName) { - return addArgument(argumentName, PythonLikeTuple.class, ArgumentKind.VARARGS, null, + return addArgument(argumentName, PythonLikeTuple.class.getName(), ArgumentKind.VARARGS, null, Optional.of(getTotalArgumentCount()), Optional.empty(), false); } public ArgumentSpec addExtraKeywordVarArgument(String argumentName) { - return addArgument(argumentName, PythonLikeDict.class, ArgumentKind.VARARGS, null, + return addArgument(argumentName, PythonLikeDict.class.getName(), ArgumentKind.VARARGS, null, Optional.empty(), Optional.of(getTotalArgumentCount()), false); } @@ -501,10 +547,12 @@ public PythonFunctionSignature asClassPythonFunctionSignature(String internalCla } private void verifyMethodMatchesSpec(Method method) { + computeArgumentTypeList(); + if (!functionReturnType.isAssignableFrom(method.getReturnType())) { throw new IllegalArgumentException("Method (" + method + ") does not match the given spec (" + this + "): its return type (" + method.getReturnType() + ") is not " + - "assignable to the spec return type (" + functionReturnType + ")."); + "assignable to the spec return type (" + functionReturnTypeName + ")."); } if (method.getParameterCount() != argumentNameList.size()) { @@ -525,6 +573,8 @@ private void verifyMethodMatchesSpec(Method method) { @SuppressWarnings("unchecked") private PythonFunctionSignature getPythonFunctionSignatureForMethodDescriptor(MethodDescriptor methodDescriptor, Class javaReturnType) { + computeArgumentTypeList(); + int firstDefault = 0; while (firstDefault < argumentDefaultList.size() && argumentDefaultList.get(firstDefault) == null && @@ -559,15 +609,19 @@ private PythonFunctionSignature getPythonFunctionSignatureForMethodDescriptor(Me this); } + public Object getDefaultValue(int defaultIndex) { + return argumentDefaultList.get(defaultIndex); + } + @Override public String toString() { StringBuilder out = new StringBuilder("ArgumentSpec("); out.append("name=").append(functionName) - .append(", returnType=").append(functionReturnType) + .append(", returnType=").append(functionReturnTypeName) .append(", arguments=["); for (int i = 0; i < argumentNameList.size(); i++) { - out.append(argumentTypeList.get(i)); + out.append(argumentTypeNameList.get(i)); out.append(" "); out.append(argumentNameList.get(i)); diff --git a/jpyinterpreter/src/main/python/__init__.py b/jpyinterpreter/src/main/python/__init__.py index f3dfb8e9..61e8114d 100644 --- a/jpyinterpreter/src/main/python/__init__.py +++ b/jpyinterpreter/src/main/python/__init__.py @@ -2,7 +2,7 @@ This module acts as an interface to the Python bytecode to Java bytecode interpreter """ from .jvm_setup import init, set_class_output_directory -from .annotations import JavaAnnotation, add_class_annotation, add_java_interface +from .annotations import JavaAnnotation, AnnotationValueSupplier, add_class_annotation, add_java_interface from .conversions import (convert_to_java_python_like_object, unwrap_python_like_object, update_python_object_from_java, is_c_native) from .translator import (translate_python_bytecode_to_java_bytecode, diff --git a/jpyinterpreter/src/main/python/annotations.py b/jpyinterpreter/src/main/python/annotations.py index c65bd506..abff81a8 100644 --- a/jpyinterpreter/src/main/python/annotations.py +++ b/jpyinterpreter/src/main/python/annotations.py @@ -1,9 +1,19 @@ +from collections import defaultdict from dataclasses import dataclass -from types import FunctionType +from types import FunctionType, NoneType, UnionType from typing import TypeVar, Any, List, Tuple, Dict, Union, Annotated, Type, Callable, \ get_origin, get_args, get_type_hints from jpype import JClass, JArray + +class AnnotationValueSupplier: + def __init__(self, supplier: Callable[[], Any]): + self.supplier = supplier + + def get_value(self) -> Any: + return self.supplier() + + @dataclass class JavaAnnotation: annotation_type: JClass @@ -49,7 +59,11 @@ def copy_type_annotations(hinted_object, default_args, vargs_name, kwargs_name): from .translator import type_to_compiled_java_class out = HashMap() - type_hints = get_type_hints(hinted_object, include_extras=True) + try: + type_hints = get_type_hints(hinted_object, include_extras=True) + except NameError: + # Occurs if get_type_hints cannot resolve a forward reference + type_hints = hinted_object.__annotations__ if hasattr(hinted_object, '__annotations__') else {} for name, type_hint in type_hints.items(): if not isinstance(name, str): @@ -60,11 +74,12 @@ def copy_type_annotations(hinted_object, default_args, vargs_name, kwargs_name): if name == kwargs_name: out.put(name, TypeHint.withoutAnnotations(type_to_compiled_java_class[dict])) continue + hint_type = type_hint hint_annotations = Collections.emptyList() if get_origin(type_hint) is Annotated: hint_type = get_args(type_hint)[0] - hint_annotations = get_java_annotations(type_hint.__metadata__) + hint_annotations = get_java_annotations(type_hint.__metadata__) # noqa if name in default_args: hint_type = Union[hint_type, type(default_args[name])] @@ -75,6 +90,20 @@ def copy_type_annotations(hinted_object, default_args, vargs_name, kwargs_name): return out +def find_closest_common_ancestor(*cls_list: type): + mros = [(list(cls.__mro__) if hasattr(cls, '__mro__') else [cls]) for cls in cls_list] + track = defaultdict(int) + while mros: + for mro in mros: + cur = mro.pop(0) + track[cur] += 1 + if track[cur] == len(cls_list): + return cur + if len(mro) == 0: + mros.remove(mro) + return object + + def get_java_type_hint(hint_type): from .translator import get_java_type_for_python_type, type_to_compiled_java_class from typing import get_args as get_generic_args @@ -83,6 +112,7 @@ def get_java_type_hint(hint_type): from ai.timefold.jpyinterpreter import TypeHint from ai.timefold.jpyinterpreter.types import BuiltinTypes from ai.timefold.jpyinterpreter.types.wrappers import JavaObjectWrapper + origin_type = get_origin(hint_type) if origin_type is None: # Happens for Callable[[parameter_types], return_type] @@ -100,12 +130,28 @@ def get_java_type_hint(hint_type): else: return TypeHint(BuiltinTypes.BASE_TYPE, Collections.emptyList()) - origin_type_hint = get_java_type_hint(origin_type) generic_args = get_generic_args(hint_type) + + if origin_type is Union or origin_type is UnionType: + union_types_excluding_none = [] + union_types_including_none = [] + for union_type in generic_args: + union_types_including_none.append(union_type) + if union_type == NoneType: + continue + union_types_excluding_none.append(union_type) + + return TypeHint(get_java_type_hint(find_closest_common_ancestor(*union_types_including_none)).type(), + Collections.emptyList(), + get_java_type_hint(find_closest_common_ancestor(*union_types_excluding_none)).type()) + + origin_type_hint = get_java_type_hint(origin_type) generic_arg_type_hint_array = JArray(TypeHint)(len(generic_args)) for i in range(len(generic_args)): generic_arg_type_hint_array[i] = get_java_type_hint(generic_args[i]) - return TypeHint(origin_type_hint.type(), Collections.emptyList(), generic_arg_type_hint_array) + + return TypeHint(origin_type_hint.type(), Collections.emptyList(), generic_arg_type_hint_array, + origin_type_hint.type()) def get_java_annotations(annotated_metadata: List[Any]): @@ -130,6 +176,9 @@ def convert_java_annotation(java_annotation: JavaAnnotation): from ai.timefold.jpyinterpreter import AnnotationMetadata annotation_values = HashMap() for attribute_name, attribute_value in java_annotation.annotation_values.items(): + if isinstance(attribute_value, AnnotationValueSupplier): + attribute_value = attribute_value.get_value() + annotation_method = java_annotation.annotation_type.class_.getDeclaredMethod(attribute_name) attribute_type = annotation_method.getReturnType() java_attribute_value = convert_annotation_value(java_annotation.annotation_type, attribute_type, @@ -144,6 +193,8 @@ def convert_annotation_value(annotation_type: JClass, attribute_type: JClass, at translate_python_bytecode_to_java_bytecode, generate_proxy_class_for_translated_function) from jpype import JBoolean, JByte, JChar, JShort, JInt, JLong, JFloat, JDouble, JString, JArray + from ai.timefold.jpyinterpreter import AnnotationMetadata + from org.objectweb.asm import Type as ASMType if attribute_value is None: return None @@ -167,21 +218,26 @@ def convert_annotation_value(annotation_type: JClass, attribute_type: JClass, at elif attribute_type == JClass('java.lang.String').class_: return JString(attribute_value) elif attribute_type == JClass('java.lang.Class').class_: - if isinstance(attribute_value, JClass('java.lang.Class')): + if isinstance(attribute_value, ASMType): return attribute_value + if isinstance(attribute_value, JClass('java.lang.Class')): + return AnnotationMetadata.getValueAsType(attribute_value.getName()) elif isinstance(attribute_value, type): - return get_java_type_for_python_type(attribute_type) + out = get_java_type_for_python_type(attribute_type) + return AnnotationMetadata.getValueAsType(out.getJavaTypeInternalName()) elif isinstance(attribute_value, FunctionType): method = annotation_type.class_.getDeclaredMethod(attribute_name) generic_type = method.getGenericReturnType() try: function_type_and_generic_args = resolve_java_function_type_as_tuple(generic_type) instance = translate_python_bytecode_to_java_bytecode(attribute_value, *function_type_and_generic_args) - return generate_proxy_class_for_translated_function(function_type_and_generic_args[0], instance) + return AnnotationMetadata.getValueAsType(generate_proxy_class_for_translated_function( + function_type_and_generic_args[0], instance).getName()) except ValueError: raw_type = resolve_raw_type(generic_type.getActualTypeArguments()[0]) instance = translate_python_bytecode_to_java_bytecode(attribute_value, raw_type) - return generate_proxy_class_for_translated_function(raw_type, instance) + return AnnotationMetadata.getValueAsType(generate_proxy_class_for_translated_function( + raw_type, instance).getName()) else: raise ValueError(f'Illegal value for {attribute_name} in annotation {annotation_type}: {attribute_value}') elif attribute_type.isEnum(): diff --git a/jpyinterpreter/src/test/java/ai/timefold/jpyinterpreter/util/arguments/ArgumentSpecTest.java b/jpyinterpreter/src/test/java/ai/timefold/jpyinterpreter/util/arguments/ArgumentSpecTest.java index 28058443..55a5daa2 100644 --- a/jpyinterpreter/src/test/java/ai/timefold/jpyinterpreter/util/arguments/ArgumentSpecTest.java +++ b/jpyinterpreter/src/test/java/ai/timefold/jpyinterpreter/util/arguments/ArgumentSpecTest.java @@ -21,7 +21,7 @@ public class ArgumentSpecTest { @Test public void testSpec() { - ArgumentSpec current = ArgumentSpec.forFunctionReturning("myFunction", PythonLikeTuple.class); + ArgumentSpec current = ArgumentSpec.forFunctionReturning("myFunction", PythonLikeTuple.class.getName()); List argumentNameList = new ArrayList<>(); List argumentValueList = new ArrayList<>(); @@ -39,7 +39,7 @@ public void testSpec() { assertThat(out).containsExactlyElementsOf(argumentValueList); } - current = current.addArgument("arg" + i, PythonInteger.class); + current = current.addArgument("arg" + i, PythonInteger.class.getName()); argumentNameList.add("arg" + i); argumentValueList.add(PythonInteger.valueOf(i)); } @@ -47,7 +47,7 @@ public void testSpec() { @Test public void testSpecWithDefaults() { - ArgumentSpec current = ArgumentSpec.forFunctionReturning("myFunction", PythonLikeTuple.class); + ArgumentSpec current = ArgumentSpec.forFunctionReturning("myFunction", PythonLikeTuple.class.getName()); List argumentNameList = new ArrayList<>(); List argumentValueList = new ArrayList<>(); @@ -73,7 +73,7 @@ public void testSpecWithDefaults() { } } - current = current.addArgument("arg" + i, PythonInteger.class, PythonInteger.valueOf(-i)); + current = current.addArgument("arg" + i, PythonInteger.class.getName(), PythonInteger.valueOf(-i)); argumentNameList.add("arg" + i); argumentValueList.add(PythonInteger.valueOf(i)); } @@ -81,8 +81,8 @@ public void testSpecWithDefaults() { @Test public void testSpecMissingArgument() { - ArgumentSpec current = ArgumentSpec.forFunctionReturning("myFunction", PythonLikeTuple.class) - .addArgument("_arg0", PythonInteger.class); + ArgumentSpec current = ArgumentSpec.forFunctionReturning("myFunction", PythonLikeTuple.class.getName()) + .addArgument("_arg0", PythonInteger.class.getName()); List argumentNameList = new ArrayList<>(); List argumentValueList = new ArrayList<>(); @@ -105,7 +105,7 @@ public void testSpecMissingArgument() { .hasMessageContaining("myFunction() missing 1 required positional argument: '"); } - current = current.addArgument("arg" + i, PythonInteger.class); + current = current.addArgument("arg" + i, PythonInteger.class.getName()); argumentNameList.add("arg" + i); argumentValueList.add(PythonInteger.valueOf(i)); } @@ -113,7 +113,7 @@ public void testSpecMissingArgument() { @Test public void testSpecExtraArgument() { - ArgumentSpec current = ArgumentSpec.forFunctionReturning("myFunction", PythonLikeTuple.class); + ArgumentSpec current = ArgumentSpec.forFunctionReturning("myFunction", PythonLikeTuple.class.getName()); List argumentNameList = new ArrayList<>(); List argumentValueList = new ArrayList<>(); @@ -147,7 +147,7 @@ public void testSpecExtraArgument() { .containsAnyOf(possibleErrorMessages); } - current = current.addArgument("arg" + i, PythonInteger.class); + current = current.addArgument("arg" + i, PythonInteger.class.getName()); argumentNameList.add("arg" + i); argumentValueList.add(PythonInteger.valueOf(i)); } diff --git a/jpyinterpreter/tests/test_classes.py b/jpyinterpreter/tests/test_classes.py index d6f8fb0f..e887ae7b 100644 --- a/jpyinterpreter/tests/test_classes.py +++ b/jpyinterpreter/tests/test_classes.py @@ -964,6 +964,28 @@ class A: assert field_type.getActualTypeArguments()[0].getName() == PythonString.class_.getName() +def test_getter_type(): + from typing import Optional, Union + from ai.timefold.jpyinterpreter.types import PythonString, PythonBytes + from ai.timefold.jpyinterpreter.types.numeric import PythonInteger + from jpyinterpreter import translate_python_class_to_java_class + + class A: + str_field: Optional[str] + int_field: Union[int, None] + bytes_field: bytes | None + + translated_class = translate_python_class_to_java_class(A).getJavaClass() + str_field_getter_type = translated_class.getMethod('getStr_field').getReturnType() + assert str_field_getter_type == PythonString.class_ + + int_field_getter_type = translated_class.getMethod('getInt_field').getReturnType() + assert int_field_getter_type == PythonInteger.class_ + + bytes_field_getter_type = translated_class.getMethod('getBytes_field').getReturnType() + assert bytes_field_getter_type == PythonBytes.class_ + + def test_marker_interface(): from ai.timefold.jpyinterpreter.types.wrappers import OpaquePythonReference from jpyinterpreter import translate_python_class_to_java_class, add_java_interface diff --git a/tests/test_vehicle_routing.py b/tests/test_vehicle_routing.py new file mode 100644 index 00000000..2c734640 --- /dev/null +++ b/tests/test_vehicle_routing.py @@ -0,0 +1,306 @@ +from datetime import datetime, timedelta + +from timefold.solver.api import * +from timefold.solver.annotation import * +from timefold.solver.config import * +from timefold.solver.constraint import ConstraintFactory +from timefold.solver.score import * + +from typing import Annotated, List, Optional +from dataclasses import dataclass, field + + +@dataclass +class Location: + latitude: float + longitude: float + driving_time_seconds: dict[int, int] = field(default_factory=dict) + + def driving_time_to(self, other: 'Location') -> int: + return self.driving_time_seconds[id(other)] + + +class ArrivalTimeUpdatingVariableListener(VariableListener): + def after_variable_changed(self, score_director: ScoreDirector, visit: 'Visit') -> None: + if visit.vehicle is None: + if visit.arrival_time is not None: + score_director.before_variable_changed(visit, 'arrival_time') + visit.arrival_time = None + score_director.after_variable_changed(visit, 'arrival_time') + return + previous_visit = visit.previous_visit + departure_time = visit.vehicle.departure_time if previous_visit is None else previous_visit.departure_time() + next_visit = visit + arrival_time = ArrivalTimeUpdatingVariableListener.calculate_arrival_time(next_visit, departure_time) + while next_visit is not None and next_visit.arrival_time != arrival_time: + score_director.before_variable_changed(next_visit, 'arrival_time') + next_visit.arrival_time = arrival_time + score_director.after_variable_changed(next_visit, 'arrival_time') + departure_time = next_visit.departure_time() + next_visit = next_visit.next_visit + arrival_time = ArrivalTimeUpdatingVariableListener.calculate_arrival_time(next_visit, departure_time) + + @staticmethod + def calculate_arrival_time(visit: Optional['Visit'], previous_departure_time: Optional[datetime]) \ + -> datetime | None: + if visit is None or previous_departure_time is None: + return None + return previous_departure_time + timedelta(seconds=visit.driving_time_seconds_from_previous_standstill()) + + +@planning_entity +@dataclass +class Visit: + id: Annotated[str, PlanningId] + name: str + location: Location + demand: int + min_start_time: datetime + max_end_time: datetime + service_duration: timedelta + vehicle: Annotated[Optional['Vehicle'], InverseRelationShadowVariable(source_variable_name='visits')] = ( + field(default=None)) + previous_visit: Annotated[Optional['Visit'], PreviousElementShadowVariable(source_variable_name='visits')] = ( + field(default=None)) + next_visit: Annotated[Optional['Visit'], + NextElementShadowVariable(source_variable_name='visits')] = field(default=None) + arrival_time: Annotated[Optional[datetime], + ShadowVariable(variable_listener_class=ArrivalTimeUpdatingVariableListener, + source_variable_name='vehicle'), + ShadowVariable(variable_listener_class=ArrivalTimeUpdatingVariableListener, + source_variable_name='previous_visit')] = field(default=None) + + def departure_time(self) -> Optional[datetime]: + if self.arrival_time is None: + return None + + return self.arrival_time + self.service_duration + + def start_service_time(self) -> Optional[datetime]: + if self.arrival_time is None: + return None + return self.min_start_time if (self.min_start_time < self.arrival_time) else self.arrival_time + + def is_service_finished_after_max_end_time(self) -> bool: + return self.arrival_time is not None and self.departure_time() > self.max_end_time + + def service_finished_delay_in_minutes(self) -> int: + if self.arrival_time is None: + return 0 + return (self.max_end_time - self.departure_time()).seconds // 60 + + def driving_time_seconds_from_previous_standstill(self) -> int: + if self.vehicle is None: + raise ValueError("This method must not be called when the shadow variables are not initialized yet.") + + if self.previous_visit is None: + return self.vehicle.home_location.driving_time_to(self.location) + else: + return self.previous_visit.location.driving_time_to(self.location) + + def driving_time_seconds_from_previous_standstill_or_none(self) -> Optional[int]: + if self.vehicle is None: + return None + return self.driving_time_seconds_from_previous_standstill() + + def __str__(self): + return self.id + + +@planning_entity +@dataclass +class Vehicle: + id: Annotated[str, PlanningId] + capacity: int + home_location: Location + departure_time: datetime + visits: Annotated[list[Visit], PlanningListVariable] = field(default_factory=list) + + def total_demand(self) -> int: + total_demand = 0 + for visit in self.visits: + total_demand += visit.demand + return total_demand + + def total_driving_time_seconds(self) -> int: + if len(self.visits) == 0: + return 0 + total_driving_time_seconds = 0 + previous_location = self.home_location + + for visit in self.visits: + total_driving_time_seconds += previous_location.driving_time_to(visit.location) + previous_location = visit.location + + total_driving_time_seconds += previous_location.driving_time_to(self.home_location) + return total_driving_time_seconds + + def arrival_time(self): + if len(self.visits) == 0: + return self.departure_time + + last_visit = self.visits[-1] + return (last_visit.departure_time() + + timedelta(seconds=last_visit.location.driving_time_to(self.home_location))) + + +@planning_solution +@dataclass +class VehicleRoutePlan: + vehicles: Annotated[list[Vehicle], PlanningEntityCollectionProperty] + visits: Annotated[list[Visit], PlanningEntityCollectionProperty, ValueRangeProvider] + score: Annotated[HardSoftScore, PlanningScore] = field(default=None) + + +@constraint_provider +def vehicle_routing_constraints(factory: ConstraintFactory): + return [ + vehicle_capacity(factory), + service_finished_after_max_end_time(factory), + minimize_travel_time(factory) + ] + +############################################## +# Hard constraints +############################################## + + +def vehicle_capacity(factory: ConstraintFactory): + return (factory.for_each(Vehicle) + .filter(lambda vehicle: vehicle.total_demand() > vehicle.capacity) + .penalize(HardSoftScore.ONE_HARD, + lambda vehicle: vehicle.total_demand() - vehicle.capacity) + .as_constraint('VEHICLE_CAPACITY') + ) + + +def service_finished_after_max_end_time(factory: ConstraintFactory): + return (factory.for_each(Visit) + .filter(lambda visit: visit.is_service_finished_after_max_end_time()) + .penalize(HardSoftScore.ONE_HARD, + lambda visit: visit.service_finished_delay_in_minutes()) + .as_constraint('SERVICE_FINISHED_AFTER_MAX_END_TIME') + ) + +############################################## +# Soft constraints +############################################## + + +def minimize_travel_time(factory: ConstraintFactory): + return ( + factory.for_each(Vehicle) + .penalize(HardSoftScore.ONE_SOFT, + lambda vehicle: vehicle.total_driving_time_seconds()) + .as_constraint('MINIMIZE_TRAVEL_TIME') + ) + + +def test_vrp(): + solver_config = SolverConfig( + solution_class=VehicleRoutePlan, + entity_class_list=[Vehicle, Visit], + score_director_factory_config=ScoreDirectorFactoryConfig( + constraint_provider_function=vehicle_routing_constraints + ), + termination_config=TerminationConfig( + best_score_limit='0hard/-300soft' + ) + ) + + solver = SolverFactory.create(solver_config).build_solver() + l1 = Location(1, 1) + l2 = Location(2, 2) + l3 = Location(3, 3) + l4 = Location(4, 4) + l5 = Location(5, 5) + + l1.driving_time_seconds = { + id(l1): 0, + id(l2): 60, + id(l3): 60 * 60, + id(l4): 60 * 60, + id(l5): 60 * 60 + } + + l2.driving_time_seconds = { + id(l1): 60 * 60, + id(l2): 0, + id(l3): 60, + id(l4): 60 * 60, + id(l5): 60 * 60 + } + + l3.driving_time_seconds = { + id(l1): 60, + id(l2): 60 * 60, + id(l3): 0, + id(l4): 60 * 60, + id(l5): 60 * 60 + } + + l4.driving_time_seconds = { + id(l1): 60 * 60, + id(l2): 60 * 60, + id(l3): 60 * 60, + id(l4): 0, + id(l5): 60 + } + + l5.driving_time_seconds = { + id(l1): 60 * 60, + id(l2): 60 * 60, + id(l3): 60 * 60, + id(l4): 60, + id(l5): 0 + } + + problem = VehicleRoutePlan( + vehicles=[ + Vehicle( + id='A', + capacity=3, + home_location=l1, + departure_time=datetime(2020, 1, 1), + ), + Vehicle( + id='B', + capacity=3, + home_location=l4, + departure_time=datetime(2020, 1, 1), + ), + ], + visits=[ + Visit( + id='1', + name='1', + location=l2, + demand=1, + min_start_time=datetime(2020, 1, 1), + max_end_time=datetime(2020, 1, 1, hour=10), + service_duration=timedelta(hours=1), + ), + Visit( + id='2', + name='2', + location=l3, + demand=1, + min_start_time=datetime(2020, 1, 1), + max_end_time=datetime(2020, 1, 1, hour=10), + service_duration=timedelta(hours=1), + ), + Visit( + id='3', + name='3', + location=l5, + demand=1, + min_start_time=datetime(2020, 1, 1), + max_end_time=datetime(2020, 1, 1, hour=10), + service_duration=timedelta(hours=1), + ), + ] + ) + solution = solver.solve(problem) + + assert [visit.id for visit in solution.vehicles[0].visits] == ['1', '2'] + assert [visit.id for visit in solution.vehicles[1].visits] == ['3'] diff --git a/timefold-solver-python-core/src/main/python/_timefold_java_interop.py b/timefold-solver-python-core/src/main/python/_timefold_java_interop.py index 08b34b2d..38a24585 100644 --- a/timefold-solver-python-core/src/main/python/_timefold_java_interop.py +++ b/timefold-solver-python-core/src/main/python/_timefold_java_interop.py @@ -3,7 +3,7 @@ import jpype.imports from jpype.types import * import importlib.resources -from typing import cast, List, Type, TypeVar, Callable, Union, TYPE_CHECKING +from typing import cast, List, Type, TypeVar, Callable, Union, TYPE_CHECKING, Any from ._jpype_type_conversions import PythonSupplier, ConstraintProviderFunction if TYPE_CHECKING: @@ -17,6 +17,8 @@ Solution_ = TypeVar('Solution_') ProblemId_ = TypeVar('ProblemId_') Score_ = TypeVar('Score_') + +_compilation_queue: list[type] = [] _enterprise_installed: bool = False @@ -140,6 +142,37 @@ def get_class(python_class: Union[Type, Callable]) -> JClass: return cast(JClass, Object) +def get_asm_type(python_class: Union[Type, Callable]) -> Any: + """Return the ASM type for the given Python Class""" + from java.lang import Object, Class + from ai.timefold.jpyinterpreter import AnnotationMetadata + from ai.timefold.jpyinterpreter.types.wrappers import OpaquePythonReference + from jpyinterpreter import is_c_native, get_java_type_for_python_type + + if python_class is None: + return None + if isinstance(python_class, jpype.JClass): + return AnnotationMetadata.getValueAsType(python_class.class_.getName()) + if isinstance(python_class, Class): + return AnnotationMetadata.getValueAsType(python_class.getName()) + if python_class == int: + from java.lang import Integer + return AnnotationMetadata.getValueAsType(Integer.class_.getName()) + if python_class == str: + from java.lang import String + return AnnotationMetadata.getValueAsType(String.class_.getName()) + if python_class == bool: + from java.lang import Boolean + return AnnotationMetadata.getValueAsType(Boolean.class_.getName()) + if hasattr(python_class, '_timefold_java_class'): + return AnnotationMetadata.getValueAsType(python_class._timefold_java_class.getName()) + if isinstance(python_class, type): + return AnnotationMetadata.getValueAsType(get_java_type_for_python_type(python_class).getJavaTypeInternalName()) + if is_c_native(python_class): + return AnnotationMetadata.getValueAsType(OpaquePythonReference.class_.getName()) + return AnnotationMetadata.getValueAsType(Object.class_.getName()) + + def register_java_class(python_object: Solution_, java_class: JClass) -> Solution_: python_object._timefold_java_class = java_class @@ -198,25 +231,29 @@ def __exit__(self, exc_type, exc_val, exc_tb): current_thread.setContextClassLoader(self.thread_class_loader) -def compile_and_get_class(python_class): +def compile_class(python_class: type) -> None: from jpyinterpreter import translate_python_class_to_java_class ensure_init() class_identifier = _get_class_identifier_for_object(python_class) out = translate_python_class_to_java_class(python_class).getJavaClass() class_identifier_to_java_class_map[class_identifier] = out - return out -def _generate_problem_fact_class(python_class): - return compile_and_get_class(python_class) +def _add_to_compilation_queue(python_class: type | PythonSupplier) -> None: + global _compilation_queue + _compilation_queue.append(python_class) + +def _process_compilation_queue() -> None: + global _compilation_queue -def _generate_planning_entity_class(python_class: Type) -> JClass: - return compile_and_get_class(python_class) + while len(_compilation_queue) > 0: + python_class = _compilation_queue.pop(0) + if isinstance(python_class, PythonSupplier): + python_class = python_class.get() -def _generate_planning_solution_class(python_class: Type) -> JClass: - return compile_and_get_class(python_class) + compile_class(python_class) def _to_constraint_java_array(python_list: List['_Constraint']) -> JArray: diff --git a/timefold-solver-python-core/src/main/python/annotation/_annotations.py b/timefold-solver-python-core/src/main/python/annotation/_annotations.py index 4c51bfc4..c897f093 100644 --- a/timefold-solver-python-core/src/main/python/annotation/_annotations.py +++ b/timefold-solver-python-core/src/main/python/annotation/_annotations.py @@ -1,9 +1,10 @@ import jpype +from .._jpype_type_conversions import PythonSupplier from ..api import VariableListener from ..constraint import ConstraintFactory -from .._timefold_java_interop import ensure_init, _generate_constraint_provider_class, register_java_class -from jpyinterpreter import JavaAnnotation +from .._timefold_java_interop import ensure_init, _generate_constraint_provider_class, register_java_class, get_asm_type +from jpyinterpreter import JavaAnnotation, AnnotationValueSupplier from jpype import JImplements, JOverride from typing import Union, List, Callable, Type, TYPE_CHECKING, TypeVar @@ -59,37 +60,21 @@ def __init__(self, *, }) -class PlanningVariableReference(JavaAnnotation): - def __init__(self, *, - entity_class: Type = None, - variable_name: str): - ensure_init() - from ai.timefold.solver.core.api.domain.variable import ( - PlanningVariableReference as JavaPlanningVariableReference) - super().__init__(JavaPlanningVariableReference, - { - 'variableName': variable_name, - 'entityClass': entity_class, - }) - - class ShadowVariable(JavaAnnotation): def __init__(self, *, variable_listener_class: Type[VariableListener] = None, source_variable_name: str, source_entity_class: Type = None): ensure_init() - from .._timefold_java_interop import get_class - from jpyinterpreter import get_java_type_for_python_type from ai.timefold.jpyinterpreter import PythonClassTranslator - from ai.timefold.solver.core.api.domain.variable import ( - ShadowVariable as JavaShadowVariable, VariableListener as JavaVariableListener) + from ai.timefold.solver.core.api.domain.variable import ShadowVariable as JavaShadowVariable super().__init__(JavaShadowVariable, { - 'variableListenerClass': get_class(variable_listener_class), + 'variableListenerClass': AnnotationValueSupplier( + lambda: get_asm_type(variable_listener_class)), 'sourceVariableName': PythonClassTranslator.getJavaFieldName(source_variable_name), - 'sourceEntityClass': get_class(source_entity_class), + 'sourceEntityClass': source_entity_class, }) @@ -98,14 +83,13 @@ def __init__(self, *, shadow_variable_name: str, shadow_entity_class: Type = None): ensure_init() - from .._timefold_java_interop import get_class from ai.timefold.jpyinterpreter import PythonClassTranslator from ai.timefold.solver.core.api.domain.variable import ( PiggybackShadowVariable as JavaPiggybackShadowVariable) super().__init__(JavaPiggybackShadowVariable, { 'shadowVariableName': PythonClassTranslator.getJavaFieldName(shadow_variable_name), - 'shadowEntityClass': get_class(shadow_entity_class), + 'shadowEntityClass': shadow_entity_class, }) @@ -122,6 +106,32 @@ def __init__(self, *, }) +class PreviousElementShadowVariable(JavaAnnotation): + def __init__(self, *, + source_variable_name: str): + ensure_init() + from ai.timefold.jpyinterpreter import PythonClassTranslator + from ai.timefold.solver.core.api.domain.variable import ( + PreviousElementShadowVariable as JavaPreviousElementShadowVariable) + super().__init__(JavaPreviousElementShadowVariable, + { + 'sourceVariableName': PythonClassTranslator.getJavaFieldName(source_variable_name) + }) + + +class NextElementShadowVariable(JavaAnnotation): + def __init__(self, *, + source_variable_name: str): + ensure_init() + from ai.timefold.jpyinterpreter import PythonClassTranslator + from ai.timefold.solver.core.api.domain.variable import ( + NextElementShadowVariable as JavaNextElementShadowVariable) + super().__init__(JavaNextElementShadowVariable, + { + 'sourceVariableName': PythonClassTranslator.getJavaFieldName(source_variable_name) + }) + + class AnchorShadowVariable(JavaAnnotation): def __init__(self, *, source_variable_name: str): @@ -265,13 +275,13 @@ def __init__(self, a_list): from ai.timefold.solver.core.api.domain.entity import PlanningEntity as JavaPlanningEntity def planning_entity_wrapper(entity_class_argument): - from .._timefold_java_interop import _generate_planning_entity_class + from .._timefold_java_interop import _add_to_compilation_queue from ai.timefold.solver.core.api.domain.entity import PinningFilter from jpyinterpreter import add_class_annotation, translate_python_bytecode_to_java_bytecode - from typing import get_type_hints, get_origin, Annotated - type_hints = get_type_hints(entity_class_argument, include_extras=True) + from typing import get_origin, Annotated + planning_pin_field = None - for name, type_hint in type_hints.items(): + for name, type_hint in entity_class_argument.__annotations__.items(): if get_origin(type_hint) == Annotated: for metadata in type_hint.__metadata__: if metadata == PlanningPin or isinstance(metadata, PlanningPin): @@ -293,7 +303,7 @@ def planning_entity_wrapper(entity_class_argument): out = add_class_annotation(JavaPlanningEntity, pinningFilter=pinning_filter_function)(entity_class_argument) - _generate_planning_entity_class(out) + _add_to_compilation_queue(out) return out if entity_class: # Called as @planning_entity @@ -330,10 +340,10 @@ def __init__(self, a_list): """ ensure_init() from jpyinterpreter import add_class_annotation - from .._timefold_java_interop import _generate_planning_solution_class + from .._timefold_java_interop import _add_to_compilation_queue from ai.timefold.solver.core.api.domain.solution import PlanningSolution as JavaPlanningSolution out = add_class_annotation(JavaPlanningSolution)(planning_solution_class) - _generate_planning_solution_class(planning_solution_class) + _add_to_compilation_queue(planning_solution_class) return out @@ -360,15 +370,6 @@ def nearby_distance_meter(distance_function: Callable[[Origin_, Destination_], f def constraint_provider(constraint_provider_function: Callable[[ConstraintFactory], List['_Constraint']], /) \ -> Callable[[ConstraintFactory], List['_Constraint']]: - """Marks a function as a ConstraintProvider. - - The function takes a single parameter, the ConstraintFactory, and - must return a list of Constraints. - To create a Constraint, start with ConstraintFactory.from(get_class(PythonClass)). - - :type constraint_provider_function: Callable[[ConstraintFactory], List[Constraint]] - :rtype: Callable[[ConstraintFactory], List[Constraint]] - """ ensure_init() def constraint_provider_wrapper(function): @@ -522,9 +523,10 @@ def wrapper_doChange(self, solution, problem_change_director): __all__ = ['PlanningId', 'PlanningScore', 'PlanningPin', 'PlanningVariable', - 'PlanningListVariable', 'PlanningVariableReference', 'ShadowVariable', + 'PlanningListVariable', 'ShadowVariable', 'PiggybackShadowVariable', - 'IndexShadowVariable', 'AnchorShadowVariable', 'InverseRelationShadowVariable', + 'IndexShadowVariable', 'PreviousElementShadowVariable', 'NextElementShadowVariable', + 'AnchorShadowVariable', 'InverseRelationShadowVariable', 'ProblemFactProperty', 'ProblemFactCollectionProperty', 'PlanningEntityProperty', 'PlanningEntityCollectionProperty', 'ValueRangeProvider', 'DeepPlanningClone', 'ConstraintConfigurationProvider', diff --git a/timefold-solver-python-core/src/main/python/api/_solver_manager.py b/timefold-solver-python-core/src/main/python/api/_solver_manager.py index 7ce6fe14..9c6d6ceb 100644 --- a/timefold-solver-python-core/src/main/python/api/_solver_manager.py +++ b/timefold-solver-python-core/src/main/python/api/_solver_manager.py @@ -6,7 +6,7 @@ from asyncio import Future from typing import TypeVar, Generic, Callable, TYPE_CHECKING from datetime import timedelta -from enum import Enum, auto as auto_enum +from enum import Enum if TYPE_CHECKING: # These imports require a JVM to be running, so only import if type checking @@ -19,9 +19,9 @@ class SolverStatus(Enum): - NOT_SOLVING = auto_enum() - SOLVING_SCHEDULED = auto_enum() - SOLVING_ACTIVE = auto_enum() + NOT_SOLVING = 'NOT_SOLVING' + SOLVING_SCHEDULED = 'SOLVING_SCHEDULED' + SOLVING_ACTIVE = 'SOLVING_ACTIVE' @staticmethod def _from_java_enum(enum_value): diff --git a/timefold-solver-python-core/src/main/python/config/_config.py b/timefold-solver-python-core/src/main/python/config/_config.py index 43c46b11..bde7cf10 100644 --- a/timefold-solver-python-core/src/main/python/config/_config.py +++ b/timefold-solver-python-core/src/main/python/config/_config.py @@ -3,7 +3,7 @@ from typing import Any, Optional, List, Type, Callable, TypeVar, Generic, TYPE_CHECKING from dataclasses import dataclass, field -from enum import Enum, auto +from enum import Enum from pathlib import Path from jpype import JClass @@ -65,28 +65,28 @@ def _to_java_duration(self) -> '_JavaDuration': class EnvironmentMode(Enum): - NON_REPRODUCIBLE = auto() - REPRODUCIBLE = auto() - FAST_ASSERT = auto() - NON_INTRUSIVE_FULL_ASSERT = auto() - FULL_ASSERT = auto() - TRACKED_FULL_ASSERT = auto() + NON_REPRODUCIBLE = 'NON_REPRODUCIBLE' + REPRODUCIBLE = 'REPRODUCIBLE' + FAST_ASSERT = 'FAST_ASSERT' + NON_INTRUSIVE_FULL_ASSERT = 'NON_INTRUSIVE_FULL_ASSERT' + FULL_ASSERT = 'FULL_ASSERT' + TRACKED_FULL_ASSERT = 'TRACKED_FULL_ASSERT' def _get_java_enum(self): return _lookup_on_java_class(_java_environment_mode, self.name) class TerminationCompositionStyle(Enum): - OR = auto() - AND = auto() + OR = 'OR' + AND = 'AND' def _get_java_enum(self): return _lookup_on_java_class(_java_termination_composition_style, self.name) class MoveThreadCount(Enum): - AUTO = auto() - NONE = auto() + AUTO = 'AUTO' + NONE = 'NONE' class RequiresEnterpriseError(EnvironmentError): @@ -122,11 +122,14 @@ def create_from_xml_text(xml_text: str) -> 'SolverConfig': return SolverConfig(xml_source_text=xml_text) def _to_java_solver_config(self) -> '_JavaSolverConfig': - from .._timefold_java_interop import OverrideClassLoader, get_class + from .._timefold_java_interop import OverrideClassLoader, get_class, _process_compilation_queue from ai.timefold.solver.core.config.solver import SolverConfig as JavaSolverConfig from java.io import File, ByteArrayInputStream # noqa from java.lang import IllegalArgumentException from java.util import ArrayList + + _process_compilation_queue() + out = JavaSolverConfig() with OverrideClassLoader() as class_loader: out.setClassLoader(class_loader)