From 823351dc30d6cb0a13a9061346a296a0ab5773f2 Mon Sep 17 00:00:00 2001 From: Michael Carleton Date: Mon, 27 Jan 2025 03:12:56 +0000 Subject: [PATCH 1/6] add KDtree nearestNeighbor() and nearestNeighbors() --- .../locationtech/jts/index/kdtree/KdTree.java | 193 +++++++++++++++++- .../jts/index/kdtree/KdTreeTest.java | 160 +++++++++++++++ 2 files changed, 344 insertions(+), 9 deletions(-) diff --git a/modules/core/src/main/java/org/locationtech/jts/index/kdtree/KdTree.java b/modules/core/src/main/java/org/locationtech/jts/index/kdtree/KdTree.java index ad40b75f25..c9ac57d10f 100644 --- a/modules/core/src/main/java/org/locationtech/jts/index/kdtree/KdTree.java +++ b/modules/core/src/main/java/org/locationtech/jts/index/kdtree/KdTree.java @@ -15,9 +15,11 @@ import java.util.ArrayDeque; import java.util.ArrayList; import java.util.Collection; +import java.util.Collections; import java.util.Deque; import java.util.Iterator; import java.util.List; +import java.util.PriorityQueue; import org.locationtech.jts.geom.Coordinate; import org.locationtech.jts.geom.CoordinateList; @@ -29,11 +31,11 @@ * over two dimensions (X and Y). * KD-trees provide fast range searching and fast lookup for point data. * The tree is built dynamically by inserting points. - * The tree supports queries by range and for point equality. - * For querying an internal stack is used instead of recursion to avoid overflow. + * The tree supports queries by location and range, and for point equality. + * For querying, an internal stack is used instead of recursion to avoid overflow. *

* This implementation supports detecting and snapping points which are closer - * than a given distance tolerance. + * than a given distance tolerance. * If the same point (up to tolerance) is inserted * more than once, it is snapped to the existing node. * In other words, if a point is inserted which lies @@ -179,6 +181,180 @@ public KdNode insert(Coordinate p, Object data) { return insertExact(p, data); } + + /** + * Finds the nearest node in the tree to the given query point. + * + * @param query the query point + * @return the nearest node, or null if the tree is empty + */ + public KdNode nearestNeighbor(final Coordinate query) { + if (root == null) { + return null; + } + + KdNode bestNode = null; + double bestDistance = Double.POSITIVE_INFINITY; + Deque stack = new ArrayDeque<>(); + KdNode currentNode = root; + boolean isXLevel = true; + + while (currentNode != null || !stack.isEmpty()) { + if (currentNode != null) { + double currentDist = query.distanceSq(currentNode.getCoordinate()); + if (currentDist < bestDistance) { + bestNode = currentNode; + bestDistance = currentDist; + } + + boolean currentIsXLevel = isXLevel; + double splitValue = currentNode.splitValue(currentIsXLevel); + KdNode nextNode; + KdNode otherNode; + + if (currentIsXLevel) { + if (query.x < splitValue) { + nextNode = currentNode.getLeft(); + otherNode = currentNode.getRight(); + } else { + nextNode = currentNode.getRight(); + otherNode = currentNode.getLeft(); + } + } else { + if (query.y < splitValue) { + nextNode = currentNode.getLeft(); + otherNode = currentNode.getRight(); + } else { + nextNode = currentNode.getRight(); + otherNode = currentNode.getLeft(); + } + } + + stack.push(new NNStackFrame(otherNode, currentIsXLevel, splitValue)); + currentNode = nextNode; + isXLevel = !currentIsXLevel; + } else { + NNStackFrame frame = stack.pop(); + KdNode otherNode = frame.node; + boolean parentSplitAxis = frame.parentSplitAxis; + double parentSplitValue = frame.parentSplitValue; + + double 
diff = parentSplitAxis + ? query.x - parentSplitValue + : query.y - parentSplitValue; + double distanceToSplitSq = diff * diff; + + if (distanceToSplitSq < bestDistance) { + currentNode = otherNode; + isXLevel = !parentSplitAxis; + } else { + currentNode = null; + } + } + } + + return bestNode; + } + + /** + * Finds the nearest N nodes in the tree to the given query point. + * + * @param query the query point + * @param n the number of nearest nodes to find + * @return a list of the nearest nodes, sorted by distance (closest first), or an empty list if the tree is empty. + */ + public List nearestNeighbors(final Coordinate query, final int n) { + if (root == null || n <= 0) { + return Collections.emptyList(); + } + + PriorityQueue heap = new PriorityQueue<>(n, (n1, n2) -> + Double.compare(query.distanceSq(n2.getCoordinate()), query.distanceSq(n1.getCoordinate())) + ); + + Deque stack = new ArrayDeque<>(); + KdNode currentNode = root; + boolean isXLevel = true; + + while (currentNode != null || !stack.isEmpty()) { + if (currentNode != null) { + double currentDist = query.distanceSq(currentNode.getCoordinate()); + if (heap.size() < n) { + heap.offer(currentNode); + } else { + double maxDist = query.distanceSq(heap.peek().getCoordinate()); + if (currentDist < maxDist) { + heap.poll(); + heap.offer(currentNode); + } + } + + boolean currentIsXLevel = isXLevel; + double splitValue = currentNode.splitValue(currentIsXLevel); + KdNode nextNode; + KdNode otherNode; + + if (currentIsXLevel) { + if (query.x < splitValue) { + nextNode = currentNode.getLeft(); + otherNode = currentNode.getRight(); + } else { + nextNode = currentNode.getRight(); + otherNode = currentNode.getLeft(); + } + } else { + if (query.y < splitValue) { + nextNode = currentNode.getLeft(); + otherNode = currentNode.getRight(); + } else { + nextNode = currentNode.getRight(); + otherNode = currentNode.getLeft(); + } + } + + stack.push(new NNStackFrame(otherNode, currentIsXLevel, splitValue)); + currentNode = 
nextNode; + isXLevel = !currentIsXLevel; + } else { + NNStackFrame frame = stack.pop(); + KdNode otherNode = frame.node; + boolean parentSplitAxis = frame.parentSplitAxis; + double parentSplitValue = frame.parentSplitValue; + + double diff = parentSplitAxis + ? query.x - parentSplitValue + : query.y - parentSplitValue; + double distanceToSplitSq = diff * diff; + + double currentMaxDist = heap.isEmpty() ? Double.POSITIVE_INFINITY : query.distanceSq(heap.peek().getCoordinate()); + + if (distanceToSplitSq < currentMaxDist || heap.size() < n) { + currentNode = otherNode; + isXLevel = !parentSplitAxis; + } else { + currentNode = null; + } + } + } + + List result = new ArrayList<>(heap); + Collections.sort(result, (n1, n2) -> + Double.compare(query.distanceSq(n1.getCoordinate()), query.distanceSq(n2.getCoordinate())) + ); + return result; + } + + private static class NNStackFrame { + KdNode node; + boolean parentSplitAxis; + double parentSplitValue; + + NNStackFrame(KdNode node, boolean parentSplitAxis, double parentSplitValue) { + this.node = node; + this.parentSplitAxis = parentSplitAxis; + this.parentSplitValue = parentSplitValue; + } + } /** * Finds the node in the tree which is the best match for a point @@ -189,10 +365,9 @@ public KdNode insert(Coordinate p, Object data) { * existing node. * * @param p the point being inserted - * @return the best matching node - * @return null if no match was found + * @return the best matching node. null if no match was found. 
*/ - private KdNode findBestMatchNode(Coordinate p) { + public KdNode findBestMatchNode(Coordinate p) { BestMatchVisitor visitor = new BestMatchVisitor(p, tolerance); query(visitor.queryEnvelope(), visitor); return visitor.getNode(); @@ -377,8 +552,8 @@ public boolean isXLevel() { * @param queryEnv the range rectangle to query * @return a list of the KdNodes found */ - public List query(Envelope queryEnv) { - final List result = new ArrayList(); + public List query(Envelope queryEnv) { + final List result = new ArrayList(); query(queryEnv, result); return result; } @@ -391,7 +566,7 @@ public List query(Envelope queryEnv) { * @param result * a list to accumulate the result nodes into */ - public void query(Envelope queryEnv, final List result) { + public void query(Envelope queryEnv, final List result) { query(queryEnv, new KdNodeVisitor() { public void visit(KdNode node) { diff --git a/modules/core/src/test/java/org/locationtech/jts/index/kdtree/KdTreeTest.java b/modules/core/src/test/java/org/locationtech/jts/index/kdtree/KdTreeTest.java index 79a46b1f53..34d98712ee 100644 --- a/modules/core/src/test/java/org/locationtech/jts/index/kdtree/KdTreeTest.java +++ b/modules/core/src/test/java/org/locationtech/jts/index/kdtree/KdTreeTest.java @@ -12,8 +12,12 @@ package org.locationtech.jts.index.kdtree; +import java.util.ArrayList; import java.util.Arrays; +import java.util.Collection; +import java.util.Comparator; import java.util.List; +import java.util.Random; import org.locationtech.jts.geom.Coordinate; import org.locationtech.jts.geom.CoordinateArrays; @@ -103,6 +107,105 @@ public void testSizeDepth() { assertTrue( depth <= size ); } + public void testNearestNeighbor() { + int n = 1000; // Number of random points to seed + KdTree tree = new KdTree(); + Random rand = new Random(1337); + + // Seed n random points + for (int i = 0; i < n; i++) { + double x = rand.nextDouble() * 100; // Random x between 0 and 100 + double y = rand.nextDouble() * 100; // Random y between 
0 and 100 + tree.insert(new Coordinate(x, y)); + } + + // Test 5 different query points + for (int i = 0; i < 500; i++) { + double queryX = rand.nextDouble() * 100; // Random query x between 0 and 100 + double queryY = rand.nextDouble() * 100; // Random query y between 0 and 100 + Coordinate query = new Coordinate(queryX, queryY); + + // Find nearest neighbor using k-d tree + KdNode nearestNode = tree.nearestNeighbor(query); + + // Find nearest neighbor using brute-force + Coordinate bruteForceNearest = bruteForceNearestNeighbor(tree, query); + + assertEquals(nearestNode.getCoordinate(), bruteForceNearest); + } + } + + public void testNearestNeighbors() { + int n = 100; // Number of random points to seed + KdTree tree = new KdTree(); + Random rand = new Random(0); + + // Seed n random points + for (int i = 0; i < n; i++) { + double x = rand.nextDouble() * 100; // Random x between 0 and 100 + double y = rand.nextDouble() * 100; // Random y between 0 and 100 + tree.insert(new Coordinate(x, y)); + } + + // Query point + Coordinate query = new Coordinate(rand.nextDouble(), rand.nextDouble()); + int k = 50; + + // Find k-nearest neighbors using k-d tree + List nearestNodes = tree.nearestNeighbors(query, k); + + // Find k-nearest neighbors using brute-force + List bruteForceNearest = bruteForceNearestNeighbors(tree, query, k); + + // Verify that both methods return the same results + assertEquals(k, nearestNodes.size()); + for (int i = 0; i < k; i++) { + assertEquals(bruteForceNearest.get(i), nearestNodes.get(i).getCoordinate()); + } + } + + public void testPerformance() { + int n = 1_000_000; // Number of random points to seed + int k = 100; // Number of nearest neighbors to find + KdTree tree = new KdTree(); + Random rand = new Random(1); + + // Seed n random points + List points = new ArrayList<>(); + for (int i = 0; i < n; i++) { + double x = rand.nextDouble(); // Random x between 0 and 100 + double y = rand.nextDouble(); // Random y between 0 and 100 + 
points.add(new Coordinate(x, y)); + } + long startTime = System.nanoTime(); + for (Coordinate coordinate : points) { + tree.insert(coordinate); + } + long insertTime = System.nanoTime() - startTime; + System.out.println("Time to insert " + n + " points: " + (insertTime / 1_000_000) + " ms"); + + // Generate a random query point + Coordinate query = new Coordinate(rand.nextDouble(), rand.nextDouble()); + + // Time k-NN query using k-d tree + startTime = System.nanoTime(); + List nearest = tree.nearestNeighbors(query, k); + long knnTime = System.nanoTime() - startTime; + System.out.println("Time to find " + k + " nearest neighbors using k-d tree: " + (knnTime / 1_000_000) + " ms"); + + // Time k-NN query using brute-force + startTime = System.nanoTime(); + List bruteForceNearest = bruteForceNearestNeighbors(tree, query, k); + long bruteForceTime = System.nanoTime() - startTime; + System.out.println("Time to find " + k + " nearest neighbors using brute-force: " + (bruteForceTime / 1_000_000) + " ms"); + + // Verify that both methods return the same results +// assertEquals(k, nearestNodes.size()); + for (int i = 0; i < k; i++) { + assertEquals(bruteForceNearest.get(i), nearest.get(i).getCoordinate()); + } + } + private void testQuery(String wktInput, double tolerance, Envelope queryEnv, String wktExpected) { KdTree index = build(wktInput, tolerance); @@ -155,6 +258,63 @@ private void testQuery(KdTree index, assertEquals("Point query not found", node.getCoordinate(), p); } } + + // Helper method to find the nearest neighbor using brute-force + private Coordinate bruteForceNearestNeighbor(KdTree tree, Coordinate query) { + List allPoints = getAllPoints(tree); + Coordinate nearest = null; + double minDistance = Double.POSITIVE_INFINITY; + + for (Coordinate point : allPoints) { + double distance = query.distance(point); + if (distance < minDistance) { + minDistance = distance; + nearest = point; + } + } + + return nearest; + } + + private Coordinate 
bruteForceNearestNeighbor(Collection allPoints, Coordinate query) { + Coordinate nearest = null; + double minDistance = Double.POSITIVE_INFINITY; + + for (Coordinate point : allPoints) { + double distance = query.distance(point); + if (distance < minDistance) { + minDistance = distance; + nearest = point; + } + } + + return nearest; + } + + private List bruteForceNearestNeighbors(KdTree tree, Coordinate query, int k) { + List allPoints = getAllPoints(tree); + + // Sort all points by distance to the query point + allPoints.sort(Comparator.comparingDouble(point -> query.distance(point))); + + // Return the first k points + return allPoints.subList(0, Math.min(k, allPoints.size())); + } + + private List getAllPoints(KdTree tree) { + List points = new ArrayList<>(); + collectPoints(tree.getRoot(), points); + return points; + } + + private void collectPoints(KdNode node, List points) { + if (node == null) { + return; + } + points.add(node.getCoordinate()); + collectPoints(node.getLeft(), points); + collectPoints(node.getRight(), points); + } private KdTree build(String wktInput, double tolerance) { final KdTree index = new KdTree(tolerance); From 6e9a81222de8da5a058f67934fca25ae4eff4e29 Mon Sep 17 00:00:00 2001 From: Michael Carleton Date: Mon, 27 Jan 2025 13:34:32 +0000 Subject: [PATCH 2/6] add getNodes() --- .../locationtech/jts/index/kdtree/KdTree.java | 29 +++++++++++++++++++ .../jts/index/kdtree/KdTreeTest.java | 15 ++++++++++ 2 files changed, 44 insertions(+) diff --git a/modules/core/src/main/java/org/locationtech/jts/index/kdtree/KdTree.java b/modules/core/src/main/java/org/locationtech/jts/index/kdtree/KdTree.java index c9ac57d10f..aa0465ebf3 100644 --- a/modules/core/src/main/java/org/locationtech/jts/index/kdtree/KdTree.java +++ b/modules/core/src/main/java/org/locationtech/jts/index/kdtree/KdTree.java @@ -601,6 +601,35 @@ public KdNode query(Coordinate queryPt) { //-- point not found return null; } + + /** + * Performs an in-order traversal of the tree, 
collecting and returning all + * nodes that have been inserted. + * + * @return A list containing all nodes in the KdTree. Returns an empty list if + * the tree is empty. + */ + public List getNodes() { + List nodeList = new ArrayList<>(); + if (root == null) { + return nodeList; // empty list for empty tree + } + + Deque stack = new ArrayDeque<>(); + KdNode currentNode = root; + + while (currentNode != null || !stack.isEmpty()) { + if (currentNode != null) { + stack.push(currentNode); + currentNode = currentNode.getLeft(); + } else { + currentNode = stack.pop(); + nodeList.add(currentNode); + currentNode = currentNode.getRight(); + } + } + return nodeList; + } /** * Computes the depth of the tree. diff --git a/modules/core/src/test/java/org/locationtech/jts/index/kdtree/KdTreeTest.java b/modules/core/src/test/java/org/locationtech/jts/index/kdtree/KdTreeTest.java index 34d98712ee..1c201451ff 100644 --- a/modules/core/src/test/java/org/locationtech/jts/index/kdtree/KdTreeTest.java +++ b/modules/core/src/test/java/org/locationtech/jts/index/kdtree/KdTreeTest.java @@ -206,6 +206,21 @@ public void testPerformance() { } } + public void testCollectNodes() { + int n = 1000; // Number of random points to seed + KdTree tree = new KdTree(); + Random rand = new Random(1337); + + // Seed n random points + for (int i = 0; i < n; i++) { + double x = rand.nextDouble() * 100; // Random x between 0 and 100 + double y = rand.nextDouble() * 100; // Random y between 0 and 100 + tree.insert(new Coordinate(x, y)); + } + + assertEquals(n, tree.getNodes().size()); + } + private void testQuery(String wktInput, double tolerance, Envelope queryEnv, String wktExpected) { KdTree index = build(wktInput, tolerance); From fa932186843949f09e9d3763a2c64057635b6bd4 Mon Sep 17 00:00:00 2001 From: Michael Carleton Date: Mon, 27 Jan 2025 13:37:22 +0000 Subject: [PATCH 3/6] declare kdnode collection type --- .../main/java/org/locationtech/jts/index/kdtree/KdTree.java | 6 +++--- 1 file changed, 3 
insertions(+), 3 deletions(-) diff --git a/modules/core/src/main/java/org/locationtech/jts/index/kdtree/KdTree.java b/modules/core/src/main/java/org/locationtech/jts/index/kdtree/KdTree.java index aa0465ebf3..e33e7a640c 100644 --- a/modules/core/src/main/java/org/locationtech/jts/index/kdtree/KdTree.java +++ b/modules/core/src/main/java/org/locationtech/jts/index/kdtree/KdTree.java @@ -67,7 +67,7 @@ public class KdTree { * a collection of nodes * @return an array of the coordinates represented by the nodes */ - public static Coordinate[] toCoordinates(Collection kdnodes) { + public static Coordinate[] toCoordinates(Collection kdnodes) { return toCoordinates(kdnodes, false); } @@ -82,9 +82,9 @@ public static Coordinate[] toCoordinates(Collection kdnodes) { * be included multiple times * @return an array of the coordinates represented by the nodes */ - public static Coordinate[] toCoordinates(Collection kdnodes, boolean includeRepeated) { + public static Coordinate[] toCoordinates(Collection kdnodes, boolean includeRepeated) { CoordinateList coord = new CoordinateList(); - for (Iterator it = kdnodes.iterator(); it.hasNext();) { + for (Iterator it = kdnodes.iterator(); it.hasNext();) { KdNode node = (KdNode) it.next(); int count = includeRepeated ? 
node.getCount() : 1; for (int i = 0; i < count; i++) { From 53b72647e97fe9c7035836182a47b73795e562d1 Mon Sep 17 00:00:00 2001 From: Michael Carleton Date: Mon, 27 Jan 2025 13:52:58 +0000 Subject: [PATCH 4/6] amend kdtree NN tests --- .../jts/index/kdtree/KdTreeTest.java | 111 ++++++------------ 1 file changed, 38 insertions(+), 73 deletions(-) diff --git a/modules/core/src/test/java/org/locationtech/jts/index/kdtree/KdTreeTest.java b/modules/core/src/test/java/org/locationtech/jts/index/kdtree/KdTreeTest.java index 1c201451ff..ee9a930f46 100644 --- a/modules/core/src/test/java/org/locationtech/jts/index/kdtree/KdTreeTest.java +++ b/modules/core/src/test/java/org/locationtech/jts/index/kdtree/KdTreeTest.java @@ -14,10 +14,10 @@ import java.util.ArrayList; import java.util.Arrays; -import java.util.Collection; import java.util.Comparator; import java.util.List; import java.util.Random; +import java.util.stream.Collectors; import org.locationtech.jts.geom.Coordinate; import org.locationtech.jts.geom.CoordinateArrays; @@ -108,73 +108,68 @@ public void testSizeDepth() { } public void testNearestNeighbor() { - int n = 1000; // Number of random points to seed + int n = 1000; + int queries = 500; KdTree tree = new KdTree(); Random rand = new Random(1337); - // Seed n random points for (int i = 0; i < n; i++) { - double x = rand.nextDouble() * 100; // Random x between 0 and 100 - double y = rand.nextDouble() * 100; // Random y between 0 and 100 + double x = rand.nextDouble(); + double y = rand.nextDouble(); tree.insert(new Coordinate(x, y)); } - // Test 5 different query points - for (int i = 0; i < 500; i++) { - double queryX = rand.nextDouble() * 100; // Random query x between 0 and 100 - double queryY = rand.nextDouble() * 100; // Random query y between 0 and 100 + for (int i = 0; i < queries; i++) { + double queryX = rand.nextDouble(); + double queryY = rand.nextDouble(); Coordinate query = new Coordinate(queryX, queryY); - // Find nearest neighbor using k-d tree KdNode 
nearestNode = tree.nearestNeighbor(query); - // Find nearest neighbor using brute-force Coordinate bruteForceNearest = bruteForceNearestNeighbor(tree, query); assertEquals(nearestNode.getCoordinate(), bruteForceNearest); } } - + public void testNearestNeighbors() { - int n = 100; // Number of random points to seed - KdTree tree = new KdTree(); + int n = 2500; + int numTrials = 50; Random rand = new Random(0); - - // Seed n random points - for (int i = 0; i < n; i++) { - double x = rand.nextDouble() * 100; // Random x between 0 and 100 - double y = rand.nextDouble() * 100; // Random y between 0 and 100 - tree.insert(new Coordinate(x, y)); - } - // Query point - Coordinate query = new Coordinate(rand.nextDouble(), rand.nextDouble()); - int k = 50; + for (int trial = 0; trial < numTrials; trial++) { + KdTree tree = new KdTree(); - // Find k-nearest neighbors using k-d tree - List nearestNodes = tree.nearestNeighbors(query, k); + for (int i = 0; i < n; i++) { + double x = rand.nextDouble(); + double y = rand.nextDouble(); + tree.insert(new Coordinate(x, y)); + } - // Find k-nearest neighbors using brute-force - List bruteForceNearest = bruteForceNearestNeighbors(tree, query, k); + Coordinate query = new Coordinate(rand.nextDouble(), rand.nextDouble()); + int k = rand.nextInt(n/10); + + List nearestNodes = tree.nearestNeighbors(query, k); - // Verify that both methods return the same results - assertEquals(k, nearestNodes.size()); - for (int i = 0; i < k; i++) { - assertEquals(bruteForceNearest.get(i), nearestNodes.get(i).getCoordinate()); + List bruteForceNearest = bruteForceNearestNeighbors(tree, query, k); + + assertEquals(k, nearestNodes.size()); + for (int i = 0; i < k; i++) { + assertEquals(bruteForceNearest.get(i), nearestNodes.get(i).getCoordinate()); + } } } - public void testPerformance() { - int n = 1_000_000; // Number of random points to seed - int k = 100; // Number of nearest neighbors to find + public void testNearestNeighborsPerformance() { + int n = 
1_000_000; + int k = 100; KdTree tree = new KdTree(); Random rand = new Random(1); - // Seed n random points List points = new ArrayList<>(); for (int i = 0; i < n; i++) { - double x = rand.nextDouble(); // Random x between 0 and 100 - double y = rand.nextDouble(); // Random y between 0 and 100 + double x = rand.nextDouble(); + double y = rand.nextDouble(); points.add(new Coordinate(x, y)); } long startTime = System.nanoTime(); @@ -184,7 +179,6 @@ public void testPerformance() { long insertTime = System.nanoTime() - startTime; System.out.println("Time to insert " + n + " points: " + (insertTime / 1_000_000) + " ms"); - // Generate a random query point Coordinate query = new Coordinate(rand.nextDouble(), rand.nextDouble()); // Time k-NN query using k-d tree @@ -199,22 +193,19 @@ public void testPerformance() { long bruteForceTime = System.nanoTime() - startTime; System.out.println("Time to find " + k + " nearest neighbors using brute-force: " + (bruteForceTime / 1_000_000) + " ms"); - // Verify that both methods return the same results -// assertEquals(k, nearestNodes.size()); for (int i = 0; i < k; i++) { assertEquals(bruteForceNearest.get(i), nearest.get(i).getCoordinate()); } } public void testCollectNodes() { - int n = 1000; // Number of random points to seed + int n = 1000; KdTree tree = new KdTree(); Random rand = new Random(1337); - // Seed n random points for (int i = 0; i < n; i++) { - double x = rand.nextDouble() * 100; // Random x between 0 and 100 - double y = rand.nextDouble() * 100; // Random y between 0 and 100 + double x = rand.nextDouble(); + double y = rand.nextDouble(); tree.insert(new Coordinate(x, y)); } @@ -291,44 +282,18 @@ private Coordinate bruteForceNearestNeighbor(KdTree tree, Coordinate query) { return nearest; } - private Coordinate bruteForceNearestNeighbor(Collection allPoints, Coordinate query) { - Coordinate nearest = null; - double minDistance = Double.POSITIVE_INFINITY; - - for (Coordinate point : allPoints) { - double distance = 
query.distance(point); - if (distance < minDistance) { - minDistance = distance; - nearest = point; - } - } - - return nearest; - } - private List bruteForceNearestNeighbors(KdTree tree, Coordinate query, int k) { List allPoints = getAllPoints(tree); // Sort all points by distance to the query point allPoints.sort(Comparator.comparingDouble(point -> query.distance(point))); - // Return the first k points + // Return the first k points (ordered closest first) return allPoints.subList(0, Math.min(k, allPoints.size())); } private List getAllPoints(KdTree tree) { - List points = new ArrayList<>(); - collectPoints(tree.getRoot(), points); - return points; - } - - private void collectPoints(KdNode node, List points) { - if (node == null) { - return; - } - points.add(node.getCoordinate()); - collectPoints(node.getLeft(), points); - collectPoints(node.getRight(), points); + return Arrays.stream(KdTree.toCoordinates(tree.getNodes())).collect(Collectors.toList()); } private KdTree build(String wktInput, double tolerance) { From 89ffc84544fb9cb1b5677da7a8623d80fe6e8844 Mon Sep 17 00:00:00 2001 From: Michael Carleton Date: Mon, 27 Jan 2025 14:35:40 +0000 Subject: [PATCH 5/6] assorted kdtree stuff --- .../locationtech/jts/index/kdtree/KdTree.java | 73 +++++++++---------- 1 file changed, 35 insertions(+), 38 deletions(-) diff --git a/modules/core/src/main/java/org/locationtech/jts/index/kdtree/KdTree.java b/modules/core/src/main/java/org/locationtech/jts/index/kdtree/KdTree.java index e33e7a640c..04011d7014 100644 --- a/modules/core/src/main/java/org/locationtech/jts/index/kdtree/KdTree.java +++ b/modules/core/src/main/java/org/locationtech/jts/index/kdtree/KdTree.java @@ -26,34 +26,31 @@ import org.locationtech.jts.geom.Envelope; /** - * An implementation of a - * KD-Tree - * over two dimensions (X and Y). - * KD-trees provide fast range searching and fast lookup for point data. - * The tree is built dynamically by inserting points. 
- * The tree supports queries by location and range, and for point equality. - * For querying, an internal stack is used instead of recursion to avoid overflow. + * A 2D KD-Tree spatial + * index for efficient point query and retrieval. + *

+ * KD-trees provide fast range searching and fast lookup for point data. The + * tree is built dynamically by inserting points. The tree supports queries by + * location and range, and for point equality. For querying, an internal stack + * is used instead of recursion to avoid overflow. *

* This implementation supports detecting and snapping points which are closer - * than a given distance tolerance. - * If the same point (up to tolerance) is inserted - * more than once, it is snapped to the existing node. - * In other words, if a point is inserted which lies - * within the tolerance of a node already in the index, - * it is snapped to that node. - * When an inserted point is snapped to a node then a new node is not created - * but the count of the existing node is incremented. - * If more than one node in the tree is within tolerance of an inserted point, - * the closest and then lowest node is snapped to. + * than a given distance tolerance. If the same point (up to tolerance) is + * inserted more than once, it is snapped to the existing node. In other words, + * if a point is inserted which lies within the tolerance of a node already in + * the index, it is snapped to that node. When an inserted point is snapped to a + * node then a new node is not created but the count of the existing node is + * incremented. If more than one node in the tree is within tolerance of an + * inserted point, the closest and then lowest node is snapped to. *

- * The structure of a KD-Tree depends on the order of insertion of the points. - * A tree may become unbalanced if the inserted points are coherent - * (e.g. monotonic in one or both dimensions). - * A perfectly balanced tree has depth of only log2(N), - * but an unbalanced tree may be much deeper. - * This has a serious impact on query efficiency. - * One solution to this is to randomize the order of points before insertion - * (e.g. by using Fisher-Yates shuffling). + * The structure of a KD-Tree depends on the order of insertion of the points. A + * tree may become unbalanced if the inserted points are coherent (e.g. + * monotonic in one or both dimensions). A perfectly balanced tree has depth of + * only log2(N), but an unbalanced tree may be much deeper. This has a serious + * impact on query efficiency. One solution to this is to randomize the order of + * points before insertion (e.g. by using Fisher-Yates + * shuffling). * * @author David Skea * @author Martin Davis @@ -96,7 +93,8 @@ public static Coordinate[] toCoordinates(Collection kdnodes, boolean inc private KdNode root = null; private long numberOfNodes; - private double tolerance; + private final double tolerance; + private final double toleranceSq; /** * Creates a new instance of a KdTree with a snapping tolerance of 0.0. (I.e. 
@@ -116,6 +114,7 @@ public KdTree() { */ public KdTree(double tolerance) { this.tolerance = tolerance; + this.toleranceSq = tolerance*tolerance; } /** @@ -205,6 +204,9 @@ public KdNode nearestNeighbor(final Coordinate query) { if (currentDist < bestDistance) { bestNode = currentNode; bestDistance = currentDist; + if (bestDistance == 0) { + return bestNode; // Early termination + } } boolean currentIsXLevel = isXLevel; @@ -279,15 +281,12 @@ public List nearestNeighbors(final Coordinate query, final int n) { while (currentNode != null || !stack.isEmpty()) { if (currentNode != null) { double currentDist = query.distanceSq(currentNode.getCoordinate()); - if (heap.size() < n) { - heap.offer(currentNode); - } else { - double maxDist = query.distanceSq(heap.peek().getCoordinate()); - if (currentDist < maxDist) { - heap.poll(); - heap.offer(currentNode); - } - } + if (heap.size() < n || currentDist < query.distanceSq(heap.peek().getCoordinate())) { + if (heap.size() == n) { + heap.poll(); + } + heap.offer(currentNode); + } boolean currentIsXLevel = isXLevel; double splitValue = currentNode.splitValue(currentIsXLevel); @@ -434,7 +433,7 @@ private KdNode insertExact(Coordinate p, Object data) { * then top-bottom (by Y ordinate) */ while (currentNode != null) { - boolean isInTolerance = p.distance(currentNode.getCoordinate()) <= tolerance; + boolean isInTolerance = p.distanceSq(currentNode.getCoordinate()) <= toleranceSq; // check if point is already in tree (up to tolerance) and if so simply // return existing node @@ -451,10 +450,8 @@ private KdNode insertExact(Coordinate p, Object data) { } leafNode = currentNode; if (isLessThan) { - //System.out.print("L"); currentNode = currentNode.getLeft(); } else { - //System.out.print("R"); currentNode = currentNode.getRight(); } From 4a8b6375c8e32075575d259f725b6556b658a9c7 Mon Sep 17 00:00:00 2001 From: Michael Carleton Date: Wed, 29 Jan 2025 22:51:43 +0000 Subject: [PATCH 6/6] move kdtree nearestNeighbors perf to test.jts,perf --- 
.../jts/index/kdtree/KdTreeTest.java | 36 ------------ .../test/jts/perf/index/KdtreeStressTest.java | 58 +++++++++++++++++++ 2 files changed, 58 insertions(+), 36 deletions(-) diff --git a/modules/core/src/test/java/org/locationtech/jts/index/kdtree/KdTreeTest.java b/modules/core/src/test/java/org/locationtech/jts/index/kdtree/KdTreeTest.java index ee9a930f46..9eda673f87 100644 --- a/modules/core/src/test/java/org/locationtech/jts/index/kdtree/KdTreeTest.java +++ b/modules/core/src/test/java/org/locationtech/jts/index/kdtree/KdTreeTest.java @@ -160,43 +160,7 @@ public void testNearestNeighbors() { } } - public void testNearestNeighborsPerformance() { - int n = 1_000_000; - int k = 100; - KdTree tree = new KdTree(); - Random rand = new Random(1); - - List points = new ArrayList<>(); - for (int i = 0; i < n; i++) { - double x = rand.nextDouble(); - double y = rand.nextDouble(); - points.add(new Coordinate(x, y)); - } - long startTime = System.nanoTime(); - for (Coordinate coordinate : points) { - tree.insert(coordinate); - } - long insertTime = System.nanoTime() - startTime; - System.out.println("Time to insert " + n + " points: " + (insertTime / 1_000_000) + " ms"); - - Coordinate query = new Coordinate(rand.nextDouble(), rand.nextDouble()); - - // Time k-NN query using k-d tree - startTime = System.nanoTime(); - List nearest = tree.nearestNeighbors(query, k); - long knnTime = System.nanoTime() - startTime; - System.out.println("Time to find " + k + " nearest neighbors using k-d tree: " + (knnTime / 1_000_000) + " ms"); - // Time k-NN query using brute-force - startTime = System.nanoTime(); - List bruteForceNearest = bruteForceNearestNeighbors(tree, query, k); - long bruteForceTime = System.nanoTime() - startTime; - System.out.println("Time to find " + k + " nearest neighbors using brute-force: " + (bruteForceTime / 1_000_000) + " ms"); - - for (int i = 0; i < k; i++) { - assertEquals(bruteForceNearest.get(i), nearest.get(i).getCoordinate()); - } - } public void 
testCollectNodes() { int n = 1000; diff --git a/modules/core/src/test/java/test/jts/perf/index/KdtreeStressTest.java b/modules/core/src/test/java/test/jts/perf/index/KdtreeStressTest.java index ed1dbf940b..65dfc0329f 100644 --- a/modules/core/src/test/java/test/jts/perf/index/KdtreeStressTest.java +++ b/modules/core/src/test/java/test/jts/perf/index/KdtreeStressTest.java @@ -1,7 +1,15 @@ package test.jts.perf.index; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Comparator; +import java.util.List; +import java.util.Random; +import java.util.stream.Collectors; + import org.locationtech.jts.geom.Coordinate; import org.locationtech.jts.geom.Envelope; +import org.locationtech.jts.index.kdtree.KdNode; import org.locationtech.jts.index.kdtree.KdTree; /** @@ -35,6 +43,42 @@ private void run() { index.query(env); } System.out.format("Queries complete\n"); + + testNearestNeighborsPerformance(); + } + + private void testNearestNeighborsPerformance() { + int n = 1_000_000; + int k = 1000; + KdTree tree = new KdTree(); + Random rand = new Random(); + + List points = new ArrayList<>(); + for (int i = 0; i < n; i++) { + double x = rand.nextDouble(); + double y = rand.nextDouble(); + points.add(new Coordinate(x, y)); + } + long startTime = System.nanoTime(); + for (Coordinate coordinate : points) { + tree.insert(coordinate); + } + long insertTime = System.nanoTime() - startTime; + System.out.println("Time to insert " + n + " points: " + (insertTime / 1_000_000) + " ms"); + + Coordinate query = new Coordinate(rand.nextDouble(), rand.nextDouble()); + + // Time k-NN query using k-d tree + startTime = System.nanoTime(); + List nearest = tree.nearestNeighbors(query, k); + long knnTime = System.nanoTime() - startTime; + System.out.println("Time to find " + k + " nearest neighbors using k-d tree: " + (knnTime / 1_000_000) + " ms"); + + // Time k-NN query using brute-force + startTime = System.nanoTime(); + List bruteForceNearest = 
bruteForceNearestNeighbors(tree, query, k); + long bruteForceTime = System.nanoTime() - startTime; + System.out.println("Time to find " + k + " nearest neighbors using brute-force: " + (bruteForceTime / 1_000_000) + " ms"); } /** @@ -52,4 +96,18 @@ private KdTree createUnbalancedTree(int numPts) { } return index; } + + private List bruteForceNearestNeighbors(KdTree tree, Coordinate query, int k) { + List allPoints = getAllPoints(tree); + + // Sort all points by distance to the query point + allPoints.sort(Comparator.comparingDouble(point -> query.distance(point))); + + // Return the first k points (ordered closest first) + return allPoints.subList(0, Math.min(k, allPoints.size())); + } + + private List getAllPoints(KdTree tree) { + return Arrays.stream(KdTree.toCoordinates(tree.getNodes())).collect(Collectors.toList()); + } }