Skip to content
This repository was archived by the owner on Jul 17, 2024. It is now read-only.

chore: Handle forward references, repeatable annotations, and use str enums #43

Merged
merged 1 commit into from
May 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
package ai.timefold.jpyinterpreter;

import java.lang.annotation.Annotation;
import java.lang.annotation.Repeatable;
import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

import org.objectweb.asm.AnnotationVisitor;
Expand All @@ -23,6 +27,30 @@ public void addAnnotationTo(MethodVisitor methodVisitor) {
visitAnnotation(methodVisitor.visitAnnotation(Type.getDescriptor(annotationType), true));
}

public static List<AnnotationMetadata> getAnnotationListWithoutRepeatable(List<AnnotationMetadata> metadata) {
List<AnnotationMetadata> out = new ArrayList<>();
Map<Class<? extends Annotation>, List<AnnotationMetadata>> repeatableAnnotationMap = new LinkedHashMap<>();
for (AnnotationMetadata annotation : metadata) {
Repeatable repeatable = annotation.annotationType().getAnnotation(Repeatable.class);
if (repeatable == null) {
out.add(annotation);
continue;
}
var annotationContainer = repeatable.value();
repeatableAnnotationMap.computeIfAbsent(annotationContainer,
ignored -> new ArrayList<>()).add(annotation);
}
for (var entry : repeatableAnnotationMap.entrySet()) {
out.add(new AnnotationMetadata(entry.getKey(),
Map.of("value", entry.getValue().toArray(AnnotationMetadata[]::new))));
}
return out;
}

public static Type getValueAsType(String className) {
return Type.getType("L" + className.replace('.', '/') + ";");
}

private void visitAnnotation(AnnotationVisitor annotationVisitor) {
for (var entry : annotationValueMap.entrySet()) {
var annotationAttributeName = entry.getKey();
Expand All @@ -42,8 +70,8 @@ private void visitAnnotationAttribute(AnnotationVisitor annotationVisitor, Strin
return;
}

if (attributeValue instanceof Class<?> clazz) {
annotationVisitor.visit(attributeName, Type.getType(clazz));
if (attributeValue instanceof Type type) {
annotationVisitor.visit(attributeName, type);
return;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import ai.timefold.jpyinterpreter.types.collections.PythonLikeTuple;
import ai.timefold.jpyinterpreter.types.wrappers.JavaObjectWrapper;
import ai.timefold.jpyinterpreter.types.wrappers.OpaquePythonReference;
import ai.timefold.jpyinterpreter.util.JavaPythonClassWriter;
import ai.timefold.jpyinterpreter.util.arguments.ArgumentSpec;

import org.objectweb.asm.ClassWriter;
Expand Down Expand Up @@ -122,7 +123,8 @@ public static PythonLikeType translatePythonClass(PythonCompiledClass pythonComp

for (Class<?> javaInterface : pythonCompiledClass.javaInterfaces) {
javaInterfaceImplementorSet.add(
new DelegatingInterfaceImplementor(internalClassName, javaInterface, instanceMethodNameToMethodDescriptor));
new DelegatingInterfaceImplementor(internalClassName, javaInterface,
instanceMethodNameToMethodDescriptor));
}

if (pythonCompiledClass.superclassList.isEmpty()) {
Expand Down Expand Up @@ -173,14 +175,14 @@ public static PythonLikeType translatePythonClass(PythonCompiledClass pythonComp
interfaces[i] = Type.getInternalName(nonObjectInterfaceImplementors.get(i).getInterfaceClass());
}

ClassWriter classWriter = new ClassWriter(ClassWriter.COMPUTE_MAXS | ClassWriter.COMPUTE_FRAMES);
ClassWriter classWriter = new JavaPythonClassWriter(ClassWriter.COMPUTE_MAXS | ClassWriter.COMPUTE_FRAMES);

classWriter.visit(Opcodes.V11, Modifier.PUBLIC, internalClassName, null,
superClassType.getJavaTypeInternalName(), interfaces);

classWriter.visitSource(pythonCompiledClass.moduleFilePath, null);

for (var annotation : pythonCompiledClass.annotations) {
for (var annotation : AnnotationMetadata.getAnnotationListWithoutRepeatable(pythonCompiledClass.annotations)) {
annotation.addAnnotationTo(classWriter);
}

Expand Down Expand Up @@ -208,19 +210,23 @@ public static PythonLikeType translatePythonClass(PythonCompiledClass pythonComp
var typeHint = pythonCompiledClass.typeAnnotations.getOrDefault(attributeName,
TypeHint.withoutAnnotations(BuiltinTypes.BASE_TYPE));
PythonLikeType type = typeHint.type();
PythonLikeType javaGetterType = typeHint.javaGetterType();
if (type == null) { // null might be in __annotations__
type = BuiltinTypes.BASE_TYPE;
}

attributeNameToTypeMap.put(attributeName, type);
FieldVisitor fieldVisitor;
String javaFieldTypeDescriptor;
String getterTypeDescriptor;
String signature = null;
boolean isJavaType;
if (type.getJavaTypeInternalName().equals(Type.getInternalName(JavaObjectWrapper.class))) {
javaFieldTypeDescriptor = Type.getDescriptor(type.getJavaObjectWrapperType());
fieldVisitor = classWriter.visitField(Modifier.PUBLIC, getJavaFieldName(attributeName), javaFieldTypeDescriptor,
null, null);
getterTypeDescriptor = javaFieldTypeDescriptor;
fieldVisitor =
classWriter.visitField(Modifier.PUBLIC, getJavaFieldName(attributeName), javaFieldTypeDescriptor,
null, null);
isJavaType = true;
} else {
if (typeHint.genericArgs() != null) {
Expand All @@ -229,16 +235,21 @@ public static PythonLikeType translatePythonClass(PythonCompiledClass pythonComp
signature = signatureWriter.toString();
}
javaFieldTypeDescriptor = 'L' + type.getJavaTypeInternalName() + ';';
fieldVisitor = classWriter.visitField(Modifier.PUBLIC, getJavaFieldName(attributeName), javaFieldTypeDescriptor,
signature, null);
getterTypeDescriptor = javaGetterType.getJavaTypeDescriptor();
fieldVisitor =
classWriter.visitField(Modifier.PUBLIC, getJavaFieldName(attributeName), javaFieldTypeDescriptor,
signature, null);
isJavaType = false;
}
fieldVisitor.visitEnd();

createJavaGetterSetter(classWriter, preparedClassInfo,
attributeName,
Type.getType(javaFieldTypeDescriptor),
Type.getType(getterTypeDescriptor),
signature,
typeHint);

FieldDescriptor fieldDescriptor =
new FieldDescriptor(attributeName, getJavaFieldName(attributeName), internalClassName,
javaFieldTypeDescriptor, type, true, isJavaType);
Expand Down Expand Up @@ -344,7 +355,8 @@ public static PythonLikeType translatePythonClass(PythonCompiledClass pythonComp
pythonLikeType.$setAttribute("__module__", PythonString.valueOf(pythonCompiledClass.module));

PythonLikeDict annotations = new PythonLikeDict();
pythonCompiledClass.typeAnnotations.forEach((name, type) -> annotations.put(PythonString.valueOf(name), type.type()));
pythonCompiledClass.typeAnnotations
.forEach((name, type) -> annotations.put(PythonString.valueOf(name), type.type()));
pythonLikeType.$setAttribute("__annotations__", annotations);

PythonLikeTuple mro = new PythonLikeTuple();
Expand Down Expand Up @@ -552,7 +564,7 @@ private static Class<?> createPythonWrapperMethod(String methodName, PythonCompi
String className = maybeClassName;
String internalClassName = className.replace('.', '/');

ClassWriter classWriter = new ClassWriter(ClassWriter.COMPUTE_MAXS | ClassWriter.COMPUTE_FRAMES);
ClassWriter classWriter = new JavaPythonClassWriter(ClassWriter.COMPUTE_MAXS | ClassWriter.COMPUTE_FRAMES);

classWriter.visit(Opcodes.V11, Modifier.PUBLIC, internalClassName, null,
Type.getInternalName(Object.class), new String[] { interfaceDeclaration.interfaceName });
Expand Down Expand Up @@ -663,7 +675,7 @@ private static PythonLikeFunction createConstructor(String classInternalName,
String constructorClassName = maybeClassName;
String constructorInternalClassName = constructorClassName.replace('.', '/');

ClassWriter classWriter = new ClassWriter(ClassWriter.COMPUTE_MAXS | ClassWriter.COMPUTE_FRAMES);
ClassWriter classWriter = new JavaPythonClassWriter(ClassWriter.COMPUTE_MAXS | ClassWriter.COMPUTE_FRAMES);
classWriter.visit(Opcodes.V11, Modifier.PUBLIC, constructorInternalClassName, null, Type.getInternalName(Object.class),
new String[] {
Type.getInternalName(PythonLikeFunction.class)
Expand Down Expand Up @@ -775,24 +787,24 @@ private static PythonLikeFunction createConstructor(String classInternalName,

private static void createJavaGetterSetter(ClassWriter classWriter,
PreparedClassInfo preparedClassInfo,
String attributeName, Type attributeType,
String attributeName, Type attributeType, Type getterType,
String signature,
TypeHint typeHint) {
createJavaGetter(classWriter, preparedClassInfo, attributeName, attributeType, signature, typeHint);
createJavaSetter(classWriter, preparedClassInfo, attributeName, attributeType, signature, typeHint);
createJavaGetter(classWriter, preparedClassInfo, attributeName, attributeType, getterType, signature, typeHint);
createJavaSetter(classWriter, preparedClassInfo, attributeName, attributeType, getterType, signature, typeHint);
}

private static void createJavaGetter(ClassWriter classWriter, PreparedClassInfo preparedClassInfo, String attributeName,
Type attributeType, String signature, TypeHint typeHint) {
Type attributeType, Type getterType, String signature, TypeHint typeHint) {
var getterName = "get" + attributeName.substring(0, 1).toUpperCase() + attributeName.substring(1);
if (signature != null) {
if (signature != null && Objects.equals(attributeType, getterType)) {
signature = "()" + signature;
}
var getterVisitor = classWriter.visitMethod(Modifier.PUBLIC, getterName, Type.getMethodDescriptor(attributeType),
var getterVisitor = classWriter.visitMethod(Modifier.PUBLIC, getterName, Type.getMethodDescriptor(getterType),
signature, null);
var maxStack = 1;

for (var annotation : typeHint.annotationList()) {
for (var annotation : AnnotationMetadata.getAnnotationListWithoutRepeatable(typeHint.annotationList())) {
annotation.addAnnotationTo(getterVisitor);
}

Expand All @@ -813,19 +825,22 @@ private static void createJavaGetter(ClassWriter classWriter, PreparedClassInfo
// If branch is taken, stack is field
// If branch is not taken, stack is null
}
if (!Objects.equals(attributeType, getterType)) {
getterVisitor.visitTypeInsn(Opcodes.CHECKCAST, getterType.getInternalName());
}
getterVisitor.visitInsn(Opcodes.ARETURN);
getterVisitor.visitMaxs(maxStack, 0);
getterVisitor.visitEnd();
}

private static void createJavaSetter(ClassWriter classWriter, PreparedClassInfo preparedClassInfo, String attributeName,
Type attributeType, String signature, TypeHint typeHint) {
Type attributeType, Type setterType, String signature, TypeHint typeHint) {
var setterName = "set" + attributeName.substring(0, 1).toUpperCase() + attributeName.substring(1);
if (signature != null) {
if (signature != null && Objects.equals(attributeType, setterType)) {
signature = "(" + signature + ")V";
}
var setterVisitor = classWriter.visitMethod(Modifier.PUBLIC, setterName, Type.getMethodDescriptor(Type.VOID_TYPE,
attributeType),
setterType),
signature, null);
var maxStack = 2;
setterVisitor.visitCode();
Expand All @@ -845,6 +860,9 @@ private static void createJavaSetter(ClassWriter classWriter, PreparedClassInfo
// If branch is taken, stack is (non-null instance)
// If branch is not taken, stack is None
}
if (!Objects.equals(attributeType, setterType)) {
setterVisitor.visitTypeInsn(Opcodes.CHECKCAST, attributeType.getInternalName());
}
setterVisitor.visitFieldInsn(Opcodes.PUTFIELD, preparedClassInfo.classInternalName,
attributeName, attributeType.getDescriptor());
setterVisitor.visitInsn(Opcodes.RETURN);
Expand All @@ -855,7 +873,7 @@ private static void createJavaSetter(ClassWriter classWriter, PreparedClassInfo
private static void addAnnotationsToMethod(PythonCompiledFunction function, MethodVisitor methodVisitor) {
var returnTypeHint = function.typeAnnotations.get("return");
if (returnTypeHint != null) {
for (var annotation : returnTypeHint.annotationList()) {
for (var annotation : AnnotationMetadata.getAnnotationListWithoutRepeatable(returnTypeHint.annotationList())) {
annotation.addAnnotationTo(methodVisitor);
}
}
Expand Down Expand Up @@ -1444,7 +1462,7 @@ public static InterfaceDeclaration createInterfaceForFunctionSignature(FunctionS

String internalClassName = className.replace('.', '/');

ClassWriter classWriter = new ClassWriter(ClassWriter.COMPUTE_MAXS | ClassWriter.COMPUTE_FRAMES);
ClassWriter classWriter = new JavaPythonClassWriter(ClassWriter.COMPUTE_MAXS | ClassWriter.COMPUTE_FRAMES);
classWriter.visit(Opcodes.V11, Modifier.PUBLIC | Modifier.INTERFACE | Modifier.ABSTRACT, internalClassName, null,
Type.getInternalName(Object.class), null);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -208,12 +208,16 @@ private static <T> Class<T> getParameterJavaClass(List<PythonLikeType> parameter
return (Class) parameterTypeList.get(variableIndex).getJavaClassOrDefault(PythonLikeObject.class);
}

private static String getParameterJavaClassName(List<PythonLikeType> parameterTypeList, int variableIndex) {
return parameterTypeList.get(variableIndex).getJavaTypeInternalName();
}

@SuppressWarnings({ "unchecked", "rawtypes" })
public BiFunction<PythonLikeTuple, PythonLikeDict, ArgumentSpec<PythonLikeObject>> getArgumentSpecMapper() {
return (defaultPositionalArguments, defaultKeywordArguments) -> {
ArgumentSpec<PythonLikeObject> out = ArgumentSpec.forFunctionReturning(qualifiedName, getReturnType()
.map(type -> (Class) type.getJavaClassOrDefault(PythonLikeObject.class))
.orElse(PythonLikeObject.class));
.map(PythonLikeType::getJavaTypeInternalName)
.orElse(PythonLikeObject.class.getName()));

int variableIndex = 0;
int defaultPositionalStartIndex = co_argcount - defaultPositionalArguments.size();
Expand All @@ -226,23 +230,23 @@ public BiFunction<PythonLikeTuple, PythonLikeDict, ArgumentSpec<PythonLikeObject
for (; variableIndex < co_posonlyargcount; variableIndex++) {
if (variableIndex >= defaultPositionalStartIndex) {
out = out.addPositionalOnlyArgument(co_varnames.get(variableIndex),
getParameterJavaClass(parameterTypeList, variableIndex),
getParameterJavaClassName(parameterTypeList, variableIndex),
defaultPositionalArguments.get(
variableIndex - defaultPositionalStartIndex));
} else {
out = out.addPositionalOnlyArgument(co_varnames.get(variableIndex),
getParameterJavaClass(parameterTypeList, variableIndex));
getParameterJavaClassName(parameterTypeList, variableIndex));
}
}

for (; variableIndex < co_argcount; variableIndex++) {
if (variableIndex >= defaultPositionalStartIndex) {
out = out.addArgument(co_varnames.get(variableIndex),
getParameterJavaClass(parameterTypeList, variableIndex),
getParameterJavaClassName(parameterTypeList, variableIndex),
defaultPositionalArguments.get(variableIndex - defaultPositionalStartIndex));
} else {
out = out.addArgument(co_varnames.get(variableIndex),
getParameterJavaClass(parameterTypeList, variableIndex));
getParameterJavaClassName(parameterTypeList, variableIndex));
}
}

Expand All @@ -251,11 +255,11 @@ public BiFunction<PythonLikeTuple, PythonLikeDict, ArgumentSpec<PythonLikeObject
defaultKeywordArguments.get(PythonString.valueOf(co_varnames.get(variableIndex)));
if (maybeDefault != null) {
out = out.addKeywordOnlyArgument(co_varnames.get(variableIndex),
getParameterJavaClass(parameterTypeList, variableIndex),
getParameterJavaClassName(parameterTypeList, variableIndex),
maybeDefault);
} else {
out = out.addKeywordOnlyArgument(co_varnames.get(variableIndex),
getParameterJavaClass(parameterTypeList, variableIndex));
getParameterJavaClassName(parameterTypeList, variableIndex));
}
variableIndex++;
}
Expand Down
Loading