K-d tree: Difference between revisions

7,406 bytes added ,  4 years ago
Added Java solution
m (C++ bug fix)
(Added Java solution)
Line 1,623:
 
See also: [[wp:KISS_principle]]
 
=={{header|Java}}==
Based on the C++ solution.
File KdTree.java:
<lang java>import java.util.*;
 
public class KdTree {
private int dimensions_;
private Node root_ = null;
private Node best_ = null;
double bestDistance_ = 0;
int visited_ = 0;
public KdTree(int dimensions, List<Node> nodes) {
dimensions_ = dimensions;
root_ = makeTree(nodes, 0, nodes.size(), 0);
}
public Node findNearest(Node target) {
if (root_ == null)
throw new IllegalStateException("Tree is empty!");
best_ = null;
visited_ = 0;
bestDistance_ = 0;
nearest(root_, target, 0);
return best_;
}
public int visited() {
return visited_;
}
public double distance() {
return Math.sqrt(bestDistance_);
}
private void nearest(Node root, Node target, int index) {
if (root == null)
return;
++visited_;
double d = root.distance(target);
if (best_ == null || d < bestDistance_) {
bestDistance_ = d;
best_ = root;
}
if (bestDistance_ == 0)
return;
double dx = root.get(index) - target.get(index);
index = (index + 1) % dimensions_;
nearest(dx > 0 ? root.left_ : root.right_, target, index);
if (dx * dx >= bestDistance_)
return;
nearest(dx > 0 ? root.right_ : root.left_, target, index);
}
private Node makeTree(List<Node> nodes, int begin, int end, int index) {
if (end <= begin)
return null;
int n = begin + (end - begin)/2;
Node node = QuickSelect.select(nodes, begin, end - 1, n, new NodeComparator(index));
index = (index + 1) % dimensions_;
node.left_ = makeTree(nodes, begin, n, index);
node.right_ = makeTree(nodes, n + 1, end, index);
return node;
}
private static class NodeComparator implements Comparator<Node> {
private int index_;
 
private NodeComparator(int index) {
index_ = index;
}
public int compare(Node n1, Node n2) {
return Double.compare(n1.get(index_), n2.get(index_));
}
}
public static class Node {
private double[] coords_;
private Node left_ = null;
private Node right_ = null;
 
public Node(double[] coords) {
coords_ = coords;
}
public Node(double x, double y) {
this(new double[]{x, y});
}
public Node(double x, double y, double z) {
this(new double[]{x, y, z});
}
double get(int index) {
return coords_[index];
}
double distance(Node node) {
double dist = 0;
for (int i = 0; i < coords_.length; ++i) {
double d = coords_[i] - node.coords_[i];
dist += d * d;
}
return dist;
}
public String toString() {
StringBuilder s = new StringBuilder("(");
for (int i = 0; i < coords_.length; ++i) {
if (i > 0)
s.append(", ");
s.append(coords_[i]);
}
s.append(')');
return s.toString();
}
}
}</lang>
File QuickSelect.java:
<lang java>import java.util.*;
 
//
// Java implementation of quickselect algorithm.
// See https://en.wikipedia.org/wiki/Quickselect
//
public class QuickSelect {
private static Random random = new Random();
 
public static <T> T select(List<T> list, int n, Comparator<? super T> cmp) {
return select(list, 0, list.size() - 1, n, cmp);
}
public static <T> T select(List<T> list, int left, int right, int n, Comparator<? super T> cmp) {
for (;;) {
if (left == right)
return list.get(left);
int pivot = pivotIndex(left, right);
pivot = partition(list, left, right, pivot, cmp);
if (n == pivot)
return list.get(n);
else if (n < pivot)
right = pivot - 1;
else
left = pivot + 1;
}
}
private static <T> int partition(List<T> list, int left, int right, int pivot, Comparator<? super T> cmp) {
T pivotValue = list.get(pivot);
swap(list, pivot, right);
int store = left;
for (int i = left; i < right; ++i) {
if (cmp.compare(list.get(i), pivotValue) < 0) {
swap(list, store, i);
++store;
}
}
swap(list, right, store);
return store;
}
private static <T> void swap(List<T> list, int i, int j) {
T value = list.get(i);
list.set(i, list.get(j));
list.set(j, value);
}
 
private static int pivotIndex(int left, int right) {
return left + random.nextInt(right - left + 1);
}
}</lang>
File KdTreeTest.java:
<lang java>import java.util.*;
 
public class KdTreeTest {
public static void main(String[] args) {
testWikipedia();
System.out.println();
testRandom(1000);
System.out.println();
testRandom(1000000);
}
private static void testWikipedia() {
double[][] coords = {
{ 2, 3 }, { 5, 4 }, { 9, 6 }, { 4, 7 }, { 8, 1 }, { 7, 2 }
};
List<KdTree.Node> nodes = new ArrayList<>();
for (int i = 0; i < coords.length; ++i)
nodes.add(new KdTree.Node(coords[i]));
KdTree tree = new KdTree(2, nodes);
KdTree.Node nearest = tree.findNearest(new KdTree.Node(9, 2));
System.out.println("Wikipedia example data:");
System.out.println("nearest point: " + nearest);
System.out.println("distance: " + tree.distance());
System.out.println("nodes visited: " + tree.visited());
}
 
private static KdTree.Node randomPoint(Random random) {
double x = random.nextDouble();
double y = random.nextDouble();
double z = random.nextDouble();
return new KdTree.Node(x, y, x);
}
 
private static void testRandom(int points) {
Random random = new Random();
List<KdTree.Node> nodes = new ArrayList<>();
for (int i = 0; i < points; ++i)
nodes.add(randomPoint(random));
KdTree tree = new KdTree(3, nodes);
KdTree.Node target = randomPoint(random);
KdTree.Node nearest = tree.findNearest(target);
System.out.println("Random data (" + points + " points):");
System.out.println("target: " + target);
System.out.println("nearest point: " + nearest);
System.out.println("distance: " + tree.distance());
System.out.println("nodes visited: " + tree.visited());
}
}</lang>
 
{{out}}
<pre>
Wikipedia example data:
nearest point: (8.0, 1.0)
distance: 1.4142135623730951
nodes visited: 3
 
Random data (1000 points):
target: (0.2918786351725754, 0.17598290673343409, 0.2918786351725754)
nearest point: (0.3174649034203002, 0.15199166781223472, 0.3174649034203002)
distance: 0.04341536353254609
nodes visited: 29
 
Random data (1000000 points):
target: (0.9674373365069223, 0.846272104653006, 0.9674373365069223)
nearest point: (0.9672541953409804, 0.845661810787955, 0.9672541953409804)
distance: 6.629781105315114E-4
nodes visited: 42
</pre>
 
=={{header|Julia}}==
1,777

edits