diff --git a/checker/src/main/java/org/checkerframework/checker/nullness/KeyForAnnotatedTypeFactory.java b/checker/src/main/java/org/checkerframework/checker/nullness/KeyForAnnotatedTypeFactory.java index 46a2dff7a85..8cc4a79fc2f 100644 --- a/checker/src/main/java/org/checkerframework/checker/nullness/KeyForAnnotatedTypeFactory.java +++ b/checker/src/main/java/org/checkerframework/checker/nullness/KeyForAnnotatedTypeFactory.java @@ -9,10 +9,12 @@ import java.util.Collection; import java.util.Collections; import java.util.LinkedHashSet; +import java.util.List; import java.util.Set; import javax.lang.model.element.AnnotationMirror; import javax.lang.model.element.Element; import javax.lang.model.element.ExecutableElement; +import javax.lang.model.type.TypeKind; import org.checkerframework.checker.nullness.qual.KeyFor; import org.checkerframework.checker.nullness.qual.KeyForBottom; import org.checkerframework.checker.nullness.qual.PolyKeyFor; @@ -27,6 +29,7 @@ import org.checkerframework.framework.flow.CFAbstractAnalysis; import org.checkerframework.framework.type.AnnotatedTypeFactory; import org.checkerframework.framework.type.AnnotatedTypeMirror; +import org.checkerframework.framework.type.AnnotatedTypeMirror.AnnotatedDeclaredType; import org.checkerframework.framework.type.AnnotatedTypeMirror.AnnotatedExecutableType; import org.checkerframework.framework.type.GenericAnnotatedTypeFactory; import org.checkerframework.framework.type.QualifierHierarchy; @@ -65,6 +68,10 @@ public class KeyForAnnotatedTypeFactory private final ExecutableElement mapPut = TreeUtils.getMethod("java.util.Map", "put", 2, processingEnv); + /** The Map.keySet method. */ + private final ExecutableElement mapKeySet = + TreeUtils.getMethod("java.util.Map", "keySet", 0, processingEnv); + /** The KeyFor.value field/element. */ protected final ExecutableElement keyForValueElement = TreeUtils.getMethod(KeyFor.class, "value", 0, processingEnv); @@ -237,6 +244,88 @@ protected DependentTypesHelper createDependentTypesHelper() { return new KeyForDependentTypesHelper(this); } + /** + * Override to merge KeyFor annotations from Map receiver's type arguments into keySet() return + * type. + */ + @Override + protected void addComputedTypeAnnotations(Tree tree, AnnotatedTypeMirror type, boolean iUseFlow) { + super.addComputedTypeAnnotations(tree, type, iUseFlow); + + // Handle keySet() method invocations: merge KeyFor annotations from Map receiver's type + // arguments into the return type's type arguments + if (tree instanceof MethodInvocationTree) { + MethodInvocationTree methodInvocation = (MethodInvocationTree) tree; + if (TreeUtils.isMethodInvocation(methodInvocation, mapKeySet, getProcessingEnv())) { + if (type.getKind() == TypeKind.DECLARED) { + AnnotatedDeclaredType keySetReturnType = (AnnotatedDeclaredType) type; + ExpressionTree receiver = TreeUtils.getReceiverTree(methodInvocation); + if (receiver != null) { + AnnotatedTypeMirror receiverType = getAnnotatedType(receiver); + if (receiverType.getKind() == TypeKind.DECLARED) { + AnnotatedDeclaredType receiverDeclaredType = (AnnotatedDeclaredType) receiverType; + mergeKeyForFromMapReceiverIntoKeySetReturn(receiverDeclaredType, keySetReturnType); + } + } + } + } + } + } + + /** + * Merges KeyFor annotations from the Map receiver's first type argument (key type) into the Set's + * first type argument (element type) in the keySet() return type. + * + * @param mapReceiverType the type of the Map receiver (e.g., Map<@KeyFor("m") String, + * Integer>) + * @param keySetReturnType the return type of keySet() (e.g., Set<@KeyFor("mapVar") String>) + */ + private void mergeKeyForFromMapReceiverIntoKeySetReturn( + AnnotatedDeclaredType mapReceiverType, AnnotatedDeclaredType keySetReturnType) { + // Get the Map's first type argument (the key type) + List mapTypeArgs = mapReceiverType.getTypeArguments(); + if (mapTypeArgs.isEmpty()) { + return; + } + AnnotatedTypeMirror mapKeyType = mapTypeArgs.get(0); + + // Get the Set's first type argument (the element type) + List setTypeArgs = keySetReturnType.getTypeArguments(); + if (setTypeArgs.isEmpty()) { + return; + } + AnnotatedTypeMirror setElementType = setTypeArgs.get(0); + + // Extract KeyFor annotation from the Map's key type + AnnotationMirror mapKeyKeyFor = mapKeyType.getEffectiveAnnotation(KeyFor.class); + if (mapKeyKeyFor == null) { + return; + } + + // Get the KeyFor values from the Map's key type + List mapKeyForValues = + AnnotationUtils.getElementValueArray(mapKeyKeyFor, keyForValueElement, String.class); + + // Extract KeyFor annotation from the Set's element type + AnnotationMirror setElementKeyFor = setElementType.getEffectiveAnnotation(KeyFor.class); + + // Collect all KeyFor values + Set mergedKeyForValues = new LinkedHashSet<>(mapKeyForValues); + + // Add existing KeyFor values from the Set's element type + if (setElementKeyFor != null) { + List setKeyForValues = + AnnotationUtils.getElementValueArray(setElementKeyFor, keyForValueElement, String.class); + mergedKeyForValues.addAll(setKeyForValues); + } + + // Create a new KeyFor annotation with merged values + if (!mergedKeyForValues.isEmpty()) { + AnnotationMirror mergedKeyFor = createKeyForAnnotationMirrorWithValue(mergedKeyForValues); + setElementType.replaceAnnotation(mergedKeyFor); + } + } + /** * Converts KeyFor annotations with errors into {@code @UnknownKeyFor} in the type of method * invocations. This changes all qualifiers on the type of a method invocation expression, even diff --git a/checker/tests/nullness/KeyForMultiple.class b/checker/tests/nullness/KeyForMultiple.class new file mode 100644 index 00000000000..99c2c607a6a Binary files /dev/null and b/checker/tests/nullness/KeyForMultiple.class differ diff --git a/checker/tests/nullness/KeyForMultiple.java b/checker/tests/nullness/KeyForMultiple.java index 58c4ac6029e..1f0033f1816 100644 --- a/checker/tests/nullness/KeyForMultiple.java +++ b/checker/tests/nullness/KeyForMultiple.java @@ -1,7 +1,5 @@ // Test case for issue #2358: https://tinyurl.com/cfissue/#2358 -// @skip-test until the bug is fixed. - import java.util.HashMap; import java.util.Map; import java.util.Set; @@ -14,6 +12,7 @@ void m1() { Map<@KeyFor({"sharedBooks"}) String, Integer> sharedBooks = new HashMap<>(); Map<@KeyFor({"sharedBooks"}) String, Integer> sharedCounts1 = new HashMap<>(); + // :: error: (assignment) Set<@KeyFor({"sharedCounts1"}) String> sharedCountsKeys1 = sharedCounts1.keySet(); } @@ -30,6 +29,7 @@ void m3() { Map<@KeyFor({"sharedBooks"}) String, Integer> sharedBooks = new HashMap<>(); Map<@KeyFor({"sharedBooks", "sharedCounts2"}) String, Integer> sharedCounts2 = new HashMap<>(); + // :: error: (assignment) Set<@KeyFor({"sharedCounts2"}) String> sharedCountsKeys2 = sharedCounts2.keySet(); }