K-d tree: Difference between revisions

Content added Content deleted
(Updated output first D entry)
(Updated first D entry)
Line 217: Line 217:


import std.typecons, std.math, std.algorithm, std.random, std.range,
import std.typecons, std.math, std.algorithm, std.random, std.range,
std.traits;
std.traits, core.memory;


/// k-dimensional point.
/// k-dimensional point.
struct Point(size_t k, F) if (isFloatingPoint!F) {
struct Point(size_t k, F) if (isFloatingPoint!F) {
F[k] data;
F[k] data;

// alias data this; // kills DMD std.algorithm.swap inlining
// alias data this; // kills DMD std.algorithm.swap inlining
F opIndex(in size_t i) const pure nothrow { return data[i]; }
F opIndex(in size_t i) const pure nothrow { return data[i]; }

void opIndexAssign(in F x, in size_t i) pure nothrow {
void opIndexAssign(in F x, in size_t i) pure nothrow {
data[i] = x;
data[i] = x;
}
}

enum size_t length = k;
enum size_t length = k;


/// Square of the euclidean distance.
/// Square of the euclidean distance.
double sqd(in ref Point!(k, F) q) pure nothrow {
double sqd(in ref Point!(k, F) q) const pure nothrow {
double sum = 0;
double sum = 0;
foreach (dim, pCoord; data)
foreach (immutable dim, immutable pCoord; data)
sum += (pCoord - q[dim]) ^^ 2;
sum += (pCoord - q[dim]) ^^ 2;
return sum;
return sum;
Line 244: Line 247:
Point!(k, F) domElt;
Point!(k, F) domElt;
immutable int split;
immutable int split;
KdNode!(k, F)* left, right;
typeof(this)* left, right;
}
}


Line 262: Line 265:
static KdNode!(k, F)* nk2(size_t split)(Point!(k, F)[] exset)
static KdNode!(k, F)* nk2(size_t split)(Point!(k, F)[] exset)
pure {
pure {
if (exset.empty) return null;
if (exset.empty)
return null;
if (exset.length == 1)
if (exset.length == 1)
return new KdNode!(k, F)(exset[0], split, null, null);
return new KdNode!(k, F)(exset[0], split, null, null);

// Pivot choosing procedure. We find median, then find
// Pivot choosing procedure. We find median, then find
// largest index of points with median value. This
// largest index of points with median value. This
Line 280: Line 285:
nk2!nextSplit(exset[m + 1 .. $]));
nk2!nextSplit(exset[m + 1 .. $]));
}
}

this.n = nk2!0(pts);
this.n = nk2!0(pts);
this.bounds = bounds_;
this.bounds = bounds_;
Line 296: Line 302:
// counting the number nodes visited.
// counting the number nodes visited.
static Tuple!(Point!(k, F), "nearest",
static Tuple!(Point!(k, F), "nearest",
double, "distSqd",
F, "distSqd",
int, "nodesVisited")
int, "nodesVisited")
nn(KdNode!(k, F)* kd, in Point!(k, F) target,
nn(KdNode!(k, F)* kd, in Point!(k, F) target,
Orthotope!(k, F) hr, double maxDistSqd) pure nothrow {
Orthotope!(k, F) hr, F maxDistSqd) pure nothrow {
if (kd == null)
if (kd == null)
return typeof(return)(Point!(k, F)(), double.infinity, 0);
return typeof(return)(Point!(k, F)(), F.infinity, 0);


int nodesVisited = 1;
int nodesVisited = 1;
Line 314: Line 320:
Orthotope!(k, F) nearerHr, furtherHr;
Orthotope!(k, F) nearerHr, furtherHr;
if (target[s] <= pivot[s]) {
if (target[s] <= pivot[s]) {
//nearerKd, nearerHr = kd.left, leftHr
//nearerKd, nearerHr = kd.left, leftHr;
//furtherKd, furtherHr = kd.right, rightHr
//furtherKd, furtherHr = kd.right, rightHr;
nearerKd = kd.left;
nearerKd = kd.left;
nearerHr = leftHr;
nearerHr = leftHr;
Line 321: Line 327:
furtherHr = rightHr;
furtherHr = rightHr;
} else {
} else {
//nearerKd, nearerHr = kd.right, rightHr
//nearerKd, nearerHr = kd.right, rightHr;
//furtherKd, furtherHr = kd.left, leftHr
//furtherKd, furtherHr = kd.left, leftHr;
nearerKd = kd.right;
nearerKd = kd.right;
nearerHr = rightHr;
nearerHr = rightHr;
Line 356: Line 362:
}
}


return nn(t.n, p, t.bounds, double.infinity);
return nn(t.n, p, t.bounds, F.infinity);
}
}


void showNearest(size_t k, F)(in string heading, KdTree!(k, F) kd,
void showNearest(size_t k, F)(in string heading, KdTree!(k, F) kd,
in Point!(k, F) p) {
in Point!(k, F) p) {
import std.stdio;
import std.stdio: writeln;
writeln(heading, ":");
writeln(heading, ":");
writeln("Point: ", p);
writeln("Point: ", p);
Line 373: Line 379:
static Point!(k, F) randomPoint(size_t k, F)() {
static Point!(k, F) randomPoint(size_t k, F)() {
typeof(return) result;
typeof(return) result;
foreach (i; 0 .. k)
foreach (immutable i; 0 .. k)
result[i] = uniform(cast(F)0, cast(F)1);
result[i] = uniform(cast(F)0, cast(F)1);
return result;
return result;
}
}


static Point!(k, F)[] randomPoints(size_t k, F)(in int n) {
static Point!(k, F)[] randomPoints(size_t k, F)(in size_t n) {
return iota(n).map!(_ => randomPoint!(k, F)())().array();
return n.iota.map!(_ => randomPoint!(k, F)).array;
}
}


Line 385: Line 391:
rndGen.seed(1); // For repeatable outputs.
rndGen.seed(1); // For repeatable outputs.


alias TypeTuple!(2, double) D2;
alias D2 = TypeTuple!(2, double);
alias Point!D2 P;
alias P = Point!D2;
auto kd1 = KdTree!D2([P([2, 3]), P([5, 4]), P([9, 6]),
auto kd1 = KdTree!D2([P([2, 3]), P([5, 4]), P([9, 6]),
P([4, 7]), P([8, 1]), P([7, 2])],
P([4, 7]), P([8, 1]), P([7, 2])],
Line 393: Line 399:


enum int N = 400_000;
enum int N = 400_000;
alias TypeTuple!(3, float) F3;
alias F3 = TypeTuple!(3, float);
alias Point!F3 Q;
alias Q = Point!F3;
StopWatch sw;
StopWatch sw;
sw.start();
GC.disable;
sw.start;
auto kd2 = KdTree!F3(randomPoints!F3(N),
auto kd2 = KdTree!F3(randomPoints!F3(N),
Orthotope!F3(Q([0, 0, 0]), Q([1, 1, 1])));
Orthotope!F3(Q([0, 0, 0]), Q([1, 1, 1])));
sw.stop();
sw.stop;
GC.enable;
showNearest(text("k-d tree with ", N,
showNearest(text("k-d tree with ", N,
" random 3D ", F3[1].stringof,
" random 3D ", F3[1].stringof,
" points (construction time: ",
" points (construction time: ",
sw.peek().msecs, "ms)"), kd2, randomPoint!F3());
sw.peek.msecs, " ms)"), kd2, randomPoint!F3);


sw.reset();
sw.reset;
sw.start();
sw.start;
enum int M = 10_000;
enum int M = 10_000;
size_t visited = 0;
size_t visited = 0;
foreach (_; 0 .. M) {
foreach (immutable _; 0 .. M) {
immutable n = kd2.findNearest(randomPoint!F3());
immutable n = kd2.findNearest(randomPoint!F3);
visited += n.nodesVisited;
visited += n.nodesVisited;
}
}
sw.stop();
sw.stop;

writefln("Visited an average of %0.2f nodes on %d searches " ~
writefln("Visited an average of %0.2f nodes on %d searches " ~
"in %dms.", visited / cast(double)M, M, sw.peek().msecs);
"in %d ms.", visited / cast(double)M, M, sw.peek.msecs);
}</lang>
}</lang>
{{out}}
<pre>Wikipedia example data:
Point: const(Point!(2,double))([9, 2])
Nearest neighbor: immutable(Point!(2,double))([8, 1])
Distance: 1.41421
Nodes visited: 3

k-d tree with 400000 random 3D float points (construction time: 1062ms):
Point: const(Point!(3,float))([0.22012, 0.984514, 0.698782])
Nearest neighbor: immutable(Point!(3,float))([0.225766, 0.978981, 0.69885])
Distance: 0.00790531
Nodes visited: 54

Visited an average of 43.10 nodes on 10000 searches in 213ms.</pre>

{{out|Output, using the ldc2 compiler}}
{{out|Output, using the ldc2 compiler}}
<pre>Wikipedia example data:
<pre>Wikipedia example data:
Line 439: Line 433:
Nodes visited: 3
Nodes visited: 3


k-d tree with 400000 random 3D float points (construction time: 306ms):
k-d tree with 400000 random 3D float points (construction time: 250 ms):
Point: const(Point!(3, float))([0.22012, 0.984514, 0.698782])
Point: const(Point!(3, float))([0.22012, 0.984514, 0.698782])
Nearest neighbor: immutable(Point!(3, float))([0.225766, 0.978981, 0.69885])
Nearest neighbor: immutable(Point!(3, float))([0.225766, 0.978981, 0.69885])
Line 445: Line 439:
Nodes visited: 54
Nodes visited: 54


Visited an average of 43.10 nodes on 10000 searches in 47ms.</pre>
Visited an average of 43.10 nodes on 10000 searches in 33 ms.</pre>


===Faster Alternative Version===
===Faster Alternative Version===