diff --git a/modules/core/src/main/java/org/locationtech/jts/index/kdtree/KdTree.java b/modules/core/src/main/java/org/locationtech/jts/index/kdtree/KdTree.java
index ad40b75f25..04011d7014 100644
--- a/modules/core/src/main/java/org/locationtech/jts/index/kdtree/KdTree.java
+++ b/modules/core/src/main/java/org/locationtech/jts/index/kdtree/KdTree.java
@@ -15,43 +15,42 @@
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collection;
+import java.util.Collections;
import java.util.Deque;
import java.util.Iterator;
import java.util.List;
+import java.util.PriorityQueue;
import org.locationtech.jts.geom.Coordinate;
import org.locationtech.jts.geom.CoordinateList;
import org.locationtech.jts.geom.Envelope;
/**
- * An implementation of a
- * KD-Tree
- * over two dimensions (X and Y).
- * KD-trees provide fast range searching and fast lookup for point data.
- * The tree is built dynamically by inserting points.
- * The tree supports queries by range and for point equality.
- * For querying an internal stack is used instead of recursion to avoid overflow.
+ * A 2D KD-Tree spatial
+ * index for efficient point query and retrieval.
+ *
+ * KD-trees provide fast range searching and fast lookup for point data. The
+ * tree is built dynamically by inserting points. The tree supports queries by
+ * location and range, and for point equality. For querying, an internal stack
+ * is used instead of recursion to avoid overflow.
*
* This implementation supports detecting and snapping points which are closer
- * than a given distance tolerance.
- * If the same point (up to tolerance) is inserted
- * more than once, it is snapped to the existing node.
- * In other words, if a point is inserted which lies
- * within the tolerance of a node already in the index,
- * it is snapped to that node.
- * When an inserted point is snapped to a node then a new node is not created
- * but the count of the existing node is incremented.
- * If more than one node in the tree is within tolerance of an inserted point,
- * the closest and then lowest node is snapped to.
+ * than a given distance tolerance. If the same point (up to tolerance) is
+ * inserted more than once, it is snapped to the existing node. In other words,
+ * if a point is inserted which lies within the tolerance of a node already in
+ * the index, it is snapped to that node. When an inserted point is snapped to a
+ * node then a new node is not created but the count of the existing node is
+ * incremented. If more than one node in the tree is within tolerance of an
+ * inserted point, the closest and then lowest node is snapped to.
*
- * The structure of a KD-Tree depends on the order of insertion of the points.
- * A tree may become unbalanced if the inserted points are coherent
- * (e.g. monotonic in one or both dimensions).
- * A perfectly balanced tree has depth of only log2(N),
- * but an unbalanced tree may be much deeper.
- * This has a serious impact on query efficiency.
- * One solution to this is to randomize the order of points before insertion
- * (e.g. by using Fisher-Yates shuffling).
+ * The structure of a KD-Tree depends on the order of insertion of the points. A
+ * tree may become unbalanced if the inserted points are coherent (e.g.
+ * monotonic in one or both dimensions). A perfectly balanced tree has depth of
+ * only log2(N), but an unbalanced tree may be much deeper. This has a serious
+ * impact on query efficiency. One solution to this is to randomize the order of
+ * points before insertion (e.g. by using Fisher-Yates
+ * shuffling).
*
* @author David Skea
* @author Martin Davis
@@ -65,7 +64,7 @@ public class KdTree {
* a collection of nodes
* @return an array of the coordinates represented by the nodes
*/
- public static Coordinate[] toCoordinates(Collection kdnodes) {
+ public static Coordinate[] toCoordinates(Collection kdnodes) {
return toCoordinates(kdnodes, false);
}
@@ -80,9 +79,9 @@ public static Coordinate[] toCoordinates(Collection kdnodes) {
* be included multiple times
* @return an array of the coordinates represented by the nodes
*/
- public static Coordinate[] toCoordinates(Collection kdnodes, boolean includeRepeated) {
+ public static Coordinate[] toCoordinates(Collection kdnodes, boolean includeRepeated) {
CoordinateList coord = new CoordinateList();
- for (Iterator it = kdnodes.iterator(); it.hasNext();) {
+ for (Iterator it = kdnodes.iterator(); it.hasNext();) {
KdNode node = (KdNode) it.next();
int count = includeRepeated ? node.getCount() : 1;
for (int i = 0; i < count; i++) {
@@ -94,7 +93,8 @@ public static Coordinate[] toCoordinates(Collection kdnodes, boolean includeRepe
private KdNode root = null;
private long numberOfNodes;
- private double tolerance;
+ private final double tolerance;
+ private final double toleranceSq;
/**
* Creates a new instance of a KdTree with a snapping tolerance of 0.0. (I.e.
@@ -114,6 +114,7 @@ public KdTree() {
*/
public KdTree(double tolerance) {
this.tolerance = tolerance;
+ this.toleranceSq = tolerance*tolerance;
}
/**
@@ -179,6 +180,180 @@ public KdNode insert(Coordinate p, Object data) {
return insertExact(p, data);
}
+
+ /**
+ * Finds the nearest node in the tree to the given query point.
+ *
+ * @param query the query point
+ * @return the nearest node, or null if the tree is empty
+ */
+ public KdNode nearestNeighbor(final Coordinate query) {
+ if (root == null) {
+ return null;
+ }
+
+ KdNode bestNode = null;
+ double bestDistance = Double.POSITIVE_INFINITY;
+ Deque stack = new ArrayDeque<>();
+ KdNode currentNode = root;
+ boolean isXLevel = true;
+
+ while (currentNode != null || !stack.isEmpty()) {
+ if (currentNode != null) {
+ double currentDist = query.distanceSq(currentNode.getCoordinate());
+ if (currentDist < bestDistance) {
+ bestNode = currentNode;
+ bestDistance = currentDist;
+ if (bestDistance == 0) {
+ return bestNode; // Early termination
+ }
+ }
+
+ boolean currentIsXLevel = isXLevel;
+ double splitValue = currentNode.splitValue(currentIsXLevel);
+ KdNode nextNode;
+ KdNode otherNode;
+
+ if (currentIsXLevel) {
+ if (query.x < splitValue) {
+ nextNode = currentNode.getLeft();
+ otherNode = currentNode.getRight();
+ } else {
+ nextNode = currentNode.getRight();
+ otherNode = currentNode.getLeft();
+ }
+ } else {
+ if (query.y < splitValue) {
+ nextNode = currentNode.getLeft();
+ otherNode = currentNode.getRight();
+ } else {
+ nextNode = currentNode.getRight();
+ otherNode = currentNode.getLeft();
+ }
+ }
+
+ stack.push(new NNStackFrame(otherNode, currentIsXLevel, splitValue));
+ currentNode = nextNode;
+ isXLevel = !currentIsXLevel;
+ } else {
+ NNStackFrame frame = stack.pop();
+ KdNode otherNode = frame.node;
+ boolean parentSplitAxis = frame.parentSplitAxis;
+ double parentSplitValue = frame.parentSplitValue;
+
+ double diff = parentSplitAxis
+ ? query.x - parentSplitValue
+ : query.y - parentSplitValue;
+ double distanceToSplitSq = diff * diff;
+
+ if (distanceToSplitSq < bestDistance) {
+ currentNode = otherNode;
+ isXLevel = !parentSplitAxis;
+ } else {
+ currentNode = null;
+ }
+ }
+ }
+
+ return bestNode;
+ }
+
+ /**
+ * Finds the nearest N nodes in the tree to the given query point.
+ *
+ * @param query the query point
+ * @param n the number of nearest nodes to find
+ * @return a list of the nearest nodes, sorted by distance (closest first), or an empty list if the tree is empty.
+ */
+ public List nearestNeighbors(final Coordinate query, final int n) {
+ if (root == null || n <= 0) {
+ return Collections.emptyList();
+ }
+
+ PriorityQueue heap = new PriorityQueue<>(n, (n1, n2) ->
+ Double.compare(query.distanceSq(n2.getCoordinate()), query.distanceSq(n1.getCoordinate()))
+ );
+
+ Deque stack = new ArrayDeque<>();
+ KdNode currentNode = root;
+ boolean isXLevel = true;
+
+ while (currentNode != null || !stack.isEmpty()) {
+ if (currentNode != null) {
+ double currentDist = query.distanceSq(currentNode.getCoordinate());
+ if (heap.size() < n || currentDist < query.distanceSq(heap.peek().getCoordinate())) {
+ if (heap.size() == n) {
+ heap.poll();
+ }
+ heap.offer(currentNode);
+ }
+
+ boolean currentIsXLevel = isXLevel;
+ double splitValue = currentNode.splitValue(currentIsXLevel);
+ KdNode nextNode;
+ KdNode otherNode;
+
+ if (currentIsXLevel) {
+ if (query.x < splitValue) {
+ nextNode = currentNode.getLeft();
+ otherNode = currentNode.getRight();
+ } else {
+ nextNode = currentNode.getRight();
+ otherNode = currentNode.getLeft();
+ }
+ } else {
+ if (query.y < splitValue) {
+ nextNode = currentNode.getLeft();
+ otherNode = currentNode.getRight();
+ } else {
+ nextNode = currentNode.getRight();
+ otherNode = currentNode.getLeft();
+ }
+ }
+
+ stack.push(new NNStackFrame(otherNode, currentIsXLevel, splitValue));
+ currentNode = nextNode;
+ isXLevel = !currentIsXLevel;
+ } else {
+ NNStackFrame frame = stack.pop();
+ KdNode otherNode = frame.node;
+ boolean parentSplitAxis = frame.parentSplitAxis;
+ double parentSplitValue = frame.parentSplitValue;
+
+ double diff = parentSplitAxis
+ ? query.x - parentSplitValue
+ : query.y - parentSplitValue;
+ double distanceToSplitSq = diff * diff;
+
+ double currentMaxDist = heap.isEmpty() ? Double.POSITIVE_INFINITY : query.distanceSq(heap.peek().getCoordinate());
+
+ if (distanceToSplitSq < currentMaxDist || heap.size() < n) {
+ currentNode = otherNode;
+ isXLevel = !parentSplitAxis;
+ } else {
+ currentNode = null;
+ }
+ }
+ }
+
+ List result = new ArrayList<>(heap);
+ Collections.sort(result, (n1, n2) ->
+ Double.compare(query.distanceSq(n1.getCoordinate()), query.distanceSq(n2.getCoordinate()))
+ );
+ return result;
+ }
+
+ private static class NNStackFrame {
+ KdNode node;
+ boolean parentSplitAxis;
+ double parentSplitValue;
+
+ NNStackFrame(KdNode node, boolean parentSplitAxis, double parentSplitValue) {
+ this.node = node;
+ this.parentSplitAxis = parentSplitAxis;
+ this.parentSplitValue = parentSplitValue;
+ }
+ }
/**
* Finds the node in the tree which is the best match for a point
@@ -189,10 +364,9 @@ public KdNode insert(Coordinate p, Object data) {
* existing node.
*
* @param p the point being inserted
- * @return the best matching node
- * @return null if no match was found
+ * @return the best matching node. null if no match was found.
*/
- private KdNode findBestMatchNode(Coordinate p) {
+ public KdNode findBestMatchNode(Coordinate p) {
BestMatchVisitor visitor = new BestMatchVisitor(p, tolerance);
query(visitor.queryEnvelope(), visitor);
return visitor.getNode();
@@ -259,7 +433,7 @@ private KdNode insertExact(Coordinate p, Object data) {
* then top-bottom (by Y ordinate)
*/
while (currentNode != null) {
- boolean isInTolerance = p.distance(currentNode.getCoordinate()) <= tolerance;
+ boolean isInTolerance = p.distanceSq(currentNode.getCoordinate()) <= toleranceSq;
// check if point is already in tree (up to tolerance) and if so simply
// return existing node
@@ -276,10 +450,8 @@ private KdNode insertExact(Coordinate p, Object data) {
}
leafNode = currentNode;
if (isLessThan) {
- //System.out.print("L");
currentNode = currentNode.getLeft();
} else {
- //System.out.print("R");
currentNode = currentNode.getRight();
}
@@ -377,8 +549,8 @@ public boolean isXLevel() {
* @param queryEnv the range rectangle to query
* @return a list of the KdNodes found
*/
- public List query(Envelope queryEnv) {
- final List result = new ArrayList();
+ public List query(Envelope queryEnv) {
+ final List result = new ArrayList();
query(queryEnv, result);
return result;
}
@@ -391,7 +563,7 @@ public List query(Envelope queryEnv) {
* @param result
* a list to accumulate the result nodes into
*/
- public void query(Envelope queryEnv, final List result) {
+ public void query(Envelope queryEnv, final List result) {
query(queryEnv, new KdNodeVisitor() {
public void visit(KdNode node) {
@@ -426,6 +598,35 @@ public KdNode query(Coordinate queryPt) {
//-- point not found
return null;
}
+
+ /**
+ * Performs an in-order traversal of the tree, collecting and returning all
+ * nodes that have been inserted.
+ *
+ * @return A list containing all nodes in the KdTree. Returns an empty list if
+ * the tree is empty.
+ */
+ public List getNodes() {
+ List nodeList = new ArrayList<>();
+ if (root == null) {
+ return nodeList; // empty list for empty tree
+ }
+
+ Deque stack = new ArrayDeque<>();
+ KdNode currentNode = root;
+
+ while (currentNode != null || !stack.isEmpty()) {
+ if (currentNode != null) {
+ stack.push(currentNode);
+ currentNode = currentNode.getLeft();
+ } else {
+ currentNode = stack.pop();
+ nodeList.add(currentNode);
+ currentNode = currentNode.getRight();
+ }
+ }
+ return nodeList;
+ }
/**
* Computes the depth of the tree.
diff --git a/modules/core/src/test/java/org/locationtech/jts/index/kdtree/KdTreeTest.java b/modules/core/src/test/java/org/locationtech/jts/index/kdtree/KdTreeTest.java
index 79a46b1f53..9eda673f87 100644
--- a/modules/core/src/test/java/org/locationtech/jts/index/kdtree/KdTreeTest.java
+++ b/modules/core/src/test/java/org/locationtech/jts/index/kdtree/KdTreeTest.java
@@ -12,8 +12,12 @@
package org.locationtech.jts.index.kdtree;
+import java.util.ArrayList;
import java.util.Arrays;
+import java.util.Comparator;
import java.util.List;
+import java.util.Random;
+import java.util.stream.Collectors;
import org.locationtech.jts.geom.Coordinate;
import org.locationtech.jts.geom.CoordinateArrays;
@@ -103,6 +107,75 @@ public void testSizeDepth() {
assertTrue( depth <= size );
}
+ public void testNearestNeighbor() {
+ int n = 1000;
+ int queries = 500;
+ KdTree tree = new KdTree();
+ Random rand = new Random(1337);
+
+ for (int i = 0; i < n; i++) {
+ double x = rand.nextDouble();
+ double y = rand.nextDouble();
+ tree.insert(new Coordinate(x, y));
+ }
+
+ for (int i = 0; i < queries; i++) {
+ double queryX = rand.nextDouble();
+ double queryY = rand.nextDouble();
+ Coordinate query = new Coordinate(queryX, queryY);
+
+ KdNode nearestNode = tree.nearestNeighbor(query);
+
+ Coordinate bruteForceNearest = bruteForceNearestNeighbor(tree, query);
+
+ assertEquals(nearestNode.getCoordinate(), bruteForceNearest);
+ }
+ }
+
+ public void testNearestNeighbors() {
+ int n = 2500;
+ int numTrials = 50;
+ Random rand = new Random(0);
+
+ for (int trial = 0; trial < numTrials; trial++) {
+ KdTree tree = new KdTree();
+
+ for (int i = 0; i < n; i++) {
+ double x = rand.nextDouble();
+ double y = rand.nextDouble();
+ tree.insert(new Coordinate(x, y));
+ }
+
+ Coordinate query = new Coordinate(rand.nextDouble(), rand.nextDouble());
+ int k = rand.nextInt(n/10);
+
+ List nearestNodes = tree.nearestNeighbors(query, k);
+
+ List bruteForceNearest = bruteForceNearestNeighbors(tree, query, k);
+
+ assertEquals(k, nearestNodes.size());
+ for (int i = 0; i < k; i++) {
+ assertEquals(bruteForceNearest.get(i), nearestNodes.get(i).getCoordinate());
+ }
+ }
+ }
+
+
+
+ public void testCollectNodes() {
+ int n = 1000;
+ KdTree tree = new KdTree();
+ Random rand = new Random(1337);
+
+ for (int i = 0; i < n; i++) {
+ double x = rand.nextDouble();
+ double y = rand.nextDouble();
+ tree.insert(new Coordinate(x, y));
+ }
+
+ assertEquals(n, tree.getNodes().size());
+ }
+
private void testQuery(String wktInput, double tolerance,
Envelope queryEnv, String wktExpected) {
KdTree index = build(wktInput, tolerance);
@@ -155,6 +228,37 @@ private void testQuery(KdTree index,
assertEquals("Point query not found", node.getCoordinate(), p);
}
}
+
+ // Helper method to find the nearest neighbor using brute-force
+ private Coordinate bruteForceNearestNeighbor(KdTree tree, Coordinate query) {
+ List allPoints = getAllPoints(tree);
+ Coordinate nearest = null;
+ double minDistance = Double.POSITIVE_INFINITY;
+
+ for (Coordinate point : allPoints) {
+ double distance = query.distance(point);
+ if (distance < minDistance) {
+ minDistance = distance;
+ nearest = point;
+ }
+ }
+
+ return nearest;
+ }
+
+ private List bruteForceNearestNeighbors(KdTree tree, Coordinate query, int k) {
+ List allPoints = getAllPoints(tree);
+
+ // Sort all points by distance to the query point
+ allPoints.sort(Comparator.comparingDouble(point -> query.distance(point)));
+
+ // Return the first k points (ordered closest first)
+ return allPoints.subList(0, Math.min(k, allPoints.size()));
+ }
+
+ private List getAllPoints(KdTree tree) {
+ return Arrays.stream(KdTree.toCoordinates(tree.getNodes())).collect(Collectors.toList());
+ }
private KdTree build(String wktInput, double tolerance) {
final KdTree index = new KdTree(tolerance);
diff --git a/modules/core/src/test/java/test/jts/perf/index/KdtreeStressTest.java b/modules/core/src/test/java/test/jts/perf/index/KdtreeStressTest.java
index ed1dbf940b..65dfc0329f 100644
--- a/modules/core/src/test/java/test/jts/perf/index/KdtreeStressTest.java
+++ b/modules/core/src/test/java/test/jts/perf/index/KdtreeStressTest.java
@@ -1,7 +1,15 @@
package test.jts.perf.index;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Comparator;
+import java.util.List;
+import java.util.Random;
+import java.util.stream.Collectors;
+
import org.locationtech.jts.geom.Coordinate;
import org.locationtech.jts.geom.Envelope;
+import org.locationtech.jts.index.kdtree.KdNode;
import org.locationtech.jts.index.kdtree.KdTree;
/**
@@ -35,6 +43,42 @@ private void run() {
index.query(env);
}
System.out.format("Queries complete\n");
+
+ testNearestNeighborsPerformance();
+ }
+
+ private void testNearestNeighborsPerformance() {
+ int n = 1_000_000;
+ int k = 1000;
+ KdTree tree = new KdTree();
+ Random rand = new Random();
+
+ List points = new ArrayList<>();
+ for (int i = 0; i < n; i++) {
+ double x = rand.nextDouble();
+ double y = rand.nextDouble();
+ points.add(new Coordinate(x, y));
+ }
+ long startTime = System.nanoTime();
+ for (Coordinate coordinate : points) {
+ tree.insert(coordinate);
+ }
+ long insertTime = System.nanoTime() - startTime;
+ System.out.println("Time to insert " + n + " points: " + (insertTime / 1_000_000) + " ms");
+
+ Coordinate query = new Coordinate(rand.nextDouble(), rand.nextDouble());
+
+ // Time k-NN query using k-d tree
+ startTime = System.nanoTime();
+ List nearest = tree.nearestNeighbors(query, k);
+ long knnTime = System.nanoTime() - startTime;
+ System.out.println("Time to find " + k + " nearest neighbors using k-d tree: " + (knnTime / 1_000_000) + " ms");
+
+ // Time k-NN query using brute-force
+ startTime = System.nanoTime();
+ List bruteForceNearest = bruteForceNearestNeighbors(tree, query, k);
+ long bruteForceTime = System.nanoTime() - startTime;
+ System.out.println("Time to find " + k + " nearest neighbors using brute-force: " + (bruteForceTime / 1_000_000) + " ms");
}
/**
@@ -52,4 +96,18 @@ private KdTree createUnbalancedTree(int numPts) {
}
return index;
}
+
+ private List bruteForceNearestNeighbors(KdTree tree, Coordinate query, int k) {
+ List allPoints = getAllPoints(tree);
+
+ // Sort all points by distance to the query point
+ allPoints.sort(Comparator.comparingDouble(point -> query.distance(point)));
+
+ // Return the first k points (ordered closest first)
+ return allPoints.subList(0, Math.min(k, allPoints.size()));
+ }
+
+ private List getAllPoints(KdTree tree) {
+ return Arrays.stream(KdTree.toCoordinates(tree.getNodes())).collect(Collectors.toList());
+ }
}