diff --git a/modules/core/src/test/java/org/locationtech/jts/index/kdtree/KdTreeTest.java b/modules/core/src/test/java/org/locationtech/jts/index/kdtree/KdTreeTest.java index ee9a930f46..9eda673f87 100644 --- a/modules/core/src/test/java/org/locationtech/jts/index/kdtree/KdTreeTest.java +++ b/modules/core/src/test/java/org/locationtech/jts/index/kdtree/KdTreeTest.java @@ -160,43 +160,7 @@ public void testNearestNeighbors() { } } - public void testNearestNeighborsPerformance() { - int n = 1_000_000; - int k = 100; - KdTree tree = new KdTree(); - Random rand = new Random(1); - - List points = new ArrayList<>(); - for (int i = 0; i < n; i++) { - double x = rand.nextDouble(); - double y = rand.nextDouble(); - points.add(new Coordinate(x, y)); - } - long startTime = System.nanoTime(); - for (Coordinate coordinate : points) { - tree.insert(coordinate); - } - long insertTime = System.nanoTime() - startTime; - System.out.println("Time to insert " + n + " points: " + (insertTime / 1_000_000) + " ms"); - - Coordinate query = new Coordinate(rand.nextDouble(), rand.nextDouble()); - - // Time k-NN query using k-d tree - startTime = System.nanoTime(); - List nearest = tree.nearestNeighbors(query, k); - long knnTime = System.nanoTime() - startTime; - System.out.println("Time to find " + k + " nearest neighbors using k-d tree: " + (knnTime / 1_000_000) + " ms"); - // Time k-NN query using brute-force - startTime = System.nanoTime(); - List bruteForceNearest = bruteForceNearestNeighbors(tree, query, k); - long bruteForceTime = System.nanoTime() - startTime; - System.out.println("Time to find " + k + " nearest neighbors using brute-force: " + (bruteForceTime / 1_000_000) + " ms"); - - for (int i = 0; i < k; i++) { - assertEquals(bruteForceNearest.get(i), nearest.get(i).getCoordinate()); - } - } public void testCollectNodes() { int n = 1000; diff --git a/modules/core/src/test/java/test/jts/perf/index/KdtreeStressTest.java b/modules/core/src/test/java/test/jts/perf/index/KdtreeStressTest.java index ed1dbf940b..65dfc0329f 100644 --- a/modules/core/src/test/java/test/jts/perf/index/KdtreeStressTest.java +++ b/modules/core/src/test/java/test/jts/perf/index/KdtreeStressTest.java @@ -1,7 +1,15 @@ package test.jts.perf.index; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Comparator; +import java.util.List; +import java.util.Random; +import java.util.stream.Collectors; + import org.locationtech.jts.geom.Coordinate; import org.locationtech.jts.geom.Envelope; +import org.locationtech.jts.index.kdtree.KdNode; import org.locationtech.jts.index.kdtree.KdTree; /** @@ -35,6 +43,42 @@ private void run() { index.query(env); } System.out.format("Queries complete\n"); + + testNearestNeighborsPerformance(); + } + + private void testNearestNeighborsPerformance() { + int n = 1_000_000; + int k = 1000; + KdTree tree = new KdTree(); + Random rand = new Random(); + + List points = new ArrayList<>(); + for (int i = 0; i < n; i++) { + double x = rand.nextDouble(); + double y = rand.nextDouble(); + points.add(new Coordinate(x, y)); + } + long startTime = System.nanoTime(); + for (Coordinate coordinate : points) { + tree.insert(coordinate); + } + long insertTime = System.nanoTime() - startTime; + System.out.println("Time to insert " + n + " points: " + (insertTime / 1_000_000) + " ms"); + + Coordinate query = new Coordinate(rand.nextDouble(), rand.nextDouble()); + + // Time k-NN query using k-d tree + startTime = System.nanoTime(); + List nearest = tree.nearestNeighbors(query, k); + long knnTime = System.nanoTime() - startTime; + System.out.println("Time to find " + k + " nearest neighbors using k-d tree: " + (knnTime / 1_000_000) + " ms"); + + // Time k-NN query using brute-force + startTime = System.nanoTime(); + List bruteForceNearest = bruteForceNearestNeighbors(tree, query, k); + long bruteForceTime = System.nanoTime() - startTime; + System.out.println("Time to find " + k + " nearest neighbors using brute-force: " + (bruteForceTime / 1_000_000) + " ms"); } /** @@ -52,4 +96,18 @@ private KdTree createUnbalancedTree(int numPts) { } return index; } + + private List bruteForceNearestNeighbors(KdTree tree, Coordinate query, int k) { + List allPoints = getAllPoints(tree); + + // Sort all points by distance to the query point + allPoints.sort(Comparator.comparingDouble(point -> query.distance(point))); + + // Return the first k points (ordered closest first) + return allPoints.subList(0, Math.min(k, allPoints.size())); + } + + private List getAllPoints(KdTree tree) { + return Arrays.stream(KdTree.toCoordinates(tree.getNodes())).collect(Collectors.toList()); + } }