K-d tree: Difference between revisions

Content added Content deleted
Line 1,380: Line 1,380:
Visits: 39
Visits: 39
</lang>
</lang>
=={{header|Scala}}==
This example works for sequences of Int, Double, etc, so it is non-minimal due to its type-safe Numeric parameterisation.
<lang Scala>object KDTree {
import Numeric._

// Task 1A. Build tree of KDNodes
// Translated from Wikipedia
def apply[T](points: Seq[Seq[T]], depth: Int = 0)(implicit num: Numeric[T]): Option[KDNode[T]] = {
val dim = points.headOption.map(_.size) getOrElse 0
if (points.isEmpty || dim < 1) None
else {
val axis = depth % dim
val sorted = points.sortBy(_(axis))
val median = sorted(sorted.size / 2)(axis)
val (left, right) = sorted.partition(v => num.lt(v(axis), median))
Some(KDNode(right.head, apply(left, depth + 1), apply(right.tail, depth + 1), axis))
}
}

// Task 1B. KDNode class to contain Node data.
// Contains a method to find the nearest node in this subtree, translated from Wikipedia
case class KDNode[T](value: Seq[T], left: Option[KDNode[T]], right: Option[KDNode[T]], axis: Int)(implicit num: Numeric[T]) {
def nearest(to: Seq[T]): Nearest[T] = {
val default = Nearest(value, to, Set(this))
compare(to, value) match {
case 0 => default // exact match
case t =>
lazy val bestL = left.map(_ nearest to).getOrElse(default)
lazy val bestR = right.map(_ nearest to).getOrElse(default)
val branch1 = if (t < 0) bestL else bestR
val best = if (num.lt(branch1.distsq, default.distsq)) branch1 else default
val splitDist = num.minus(to(axis), value(axis))
if (num.lt(num.times(splitDist, splitDist), best.distsq)) {
val branch2 = if (t < 0) bestR else bestL
val visited = branch2.visited ++ best.visited + this
if (num.lt(branch2.distsq, best.distsq))
branch2.copy(visited = visited)
else best.copy(visited = visited)
} else best.copy(visited = best.visited + this)
}
}
}

// Numeric utilities
def distsq[T](a: Seq[T], b: Seq[T])(implicit num: Numeric[T]) =
a.zip(b).map(c => num.times(num.minus(c._1, c._2), num.minus(c._1, c._2))).sum
def compare[T](a: Seq[T], b: Seq[T])(implicit num: Numeric[T]): Int =
a.zip(b).find(c => num.compare(c._1, c._2) != 0).map(c => num.compare(c._1, c._2)).getOrElse(0)

// Something to keep track of nodes visited as per task
case class Nearest[T](value: Seq[T], to: Seq[T], visited: Set[KDNode[T]] = Set[KDNode[T]]())(implicit num: Numeric[T]) {
lazy val distsq = KDTree.distsq(value, to)
override def toString = f"Searched for=${to} found=${value} distance=${math.sqrt(num.toDouble(distsq))}%.4f visited=${visited.size}"
}
}</lang>
Task test:
<lang Scala>object KDTreeTest extends App {
def test[T](haystack: Seq[Seq[T]], needles: Seq[T]*)(implicit num: Numeric[T]) = {
println
val tree = KDTree(haystack)
if (haystack.size < 20) tree.foreach(println)
for (kd <- tree; needle <- needles; nearest = kd nearest needle) {
println(nearest)
// Brute force proof
val better = haystack
.map(KDTree.Nearest(_, needle))
.filter(n => num.lt(n.distsq, nearest.distsq))
.sortBy(_.distsq)
assert(better.isEmpty, s"Found ${better.size} closer than ${nearest.value} e.g. ${better.head}")
}
}

// Results 1
val wikitest = List(List(2,3), List(5,4), List(9,6), List(4,7), List(8,1), List(7,2))
test(wikitest, List(9,2))

// Results 2 (1000 points uniformly distributed in 3-d cube coordinates, sides 2 to 20)
val uniform = for(x <- 1 to 10; y <- 1 to 10; z <- 1 to 10) yield List(x*2, y*2, z*2)
assume(uniform.size == 1000)
test(uniform, List(0, 0, 0), List(2, 2, 20), List(9, 10, 11))

// Results 3 (1000 points randomly distributed in 3-d cube coordinates, sides -1.0 to 1.0)
scala.util.Random.setSeed(0)
def random(n: Int) = (1 to n).map(_ => (scala.util.Random.nextDouble - 0.5)* 2)
test((1 to 1000).map(_ => random(3)), random(3))

// Results 4 (27 points uniformly distributed in 3-d cube coordinates, sides 3...9)
val small = for(x <- 1 to 3; y <- 1 to 3; z <- 1 to 3) yield List(x*3, y*3, z*3)
assume(small.size == 27)
test(small, List(0, 0, 0), List(4, 5, 6))
}</lang>
{{out}}
<pre>KDNode(List(7, 2),Some(KDNode(List(5, 4),Some(KDNode(List(2, 3),None,None,0)),Some(KDNode(List(4, 7),None,None,0)),1)),Some(KDNode(List(9, 6),Some(KDNode(List(8, 1),None,None,0)),None,1)),0)
Searched for=List(9, 2) found=List(8, 1) distance=1.4142 visited=3

Searched for=List(0, 0, 0) found=List(2, 2, 2) distance=3.4641 visited=10
Searched for=List(2, 2, 20) found=List(2, 2, 20) distance=0.0000 visited=9
Searched for=List(9, 10, 11) found=List(8, 10, 12) distance=1.4142 visited=134

Searched for=Vector(0.19269603520919643, -0.25958512078298535, -0.2572864045762784) found=Vector(0.07811099409527977, -0.2477618820196814, -0.20252227622550611) distance=0.1275 visited=25

Searched for=List(0, 0, 0) found=List(3, 3, 3) distance=5.1962 visited=4
Searched for=List(4, 5, 6) found=List(3, 6, 6) distance=1.4142 visited=6</pre>


=={{header|Tcl}}==
=={{header|Tcl}}==