K-d tree: Difference between revisions
Content added Content deleted
m (→{{header|J}}) |
|||
Line 1,380: | Line 1,380: | ||
Visits: 39 |
Visits: 39 |
||
</lang> |
</lang> |
||
=={{header|Scala}}== |
|||
This example works for sequences of Int, Double, etc, so it is non-minimal due to its type-safe Numeric parameterisation. |
|||
<lang Scala>object KDTree { |
|||
import Numeric._ |
|||
// Task 1A. Build tree of KDNodes |
|||
// Translated from Wikipedia |
|||
def apply[T](points: Seq[Seq[T]], depth: Int = 0)(implicit num: Numeric[T]): Option[KDNode[T]] = { |
|||
val dim = points.headOption.map(_.size) getOrElse 0 |
|||
if (points.isEmpty || dim < 1) None |
|||
else { |
|||
val axis = depth % dim |
|||
val sorted = points.sortBy(_(axis)) |
|||
val median = sorted(sorted.size / 2)(axis) |
|||
val (left, right) = sorted.partition(v => num.lt(v(axis), median)) |
|||
Some(KDNode(right.head, apply(left, depth + 1), apply(right.tail, depth + 1), axis)) |
|||
} |
|||
} |
|||
// Task 1B. KDNode class to contain Node data. |
|||
// Contains a method to find the nearest node in this subtree, translated from Wikipedia |
|||
case class KDNode[T](value: Seq[T], left: Option[KDNode[T]], right: Option[KDNode[T]], axis: Int)(implicit num: Numeric[T]) { |
|||
def nearest(to: Seq[T]): Nearest[T] = { |
|||
val default = Nearest(value, to, Set(this)) |
|||
compare(to, value) match { |
|||
case 0 => default // exact match |
|||
case t => |
|||
lazy val bestL = left.map(_ nearest to).getOrElse(default) |
|||
lazy val bestR = right.map(_ nearest to).getOrElse(default) |
|||
val branch1 = if (t < 0) bestL else bestR |
|||
val best = if (num.lt(branch1.distsq, default.distsq)) branch1 else default |
|||
val splitDist = num.minus(to(axis), value(axis)) |
|||
if (num.lt(num.times(splitDist, splitDist), best.distsq)) { |
|||
val branch2 = if (t < 0) bestR else bestL |
|||
val visited = branch2.visited ++ best.visited + this |
|||
if (num.lt(branch2.distsq, best.distsq)) |
|||
branch2.copy(visited = visited) |
|||
else best.copy(visited = visited) |
|||
} else best.copy(visited = best.visited + this) |
|||
} |
|||
} |
|||
} |
|||
// Numeric utilities |
|||
def distsq[T](a: Seq[T], b: Seq[T])(implicit num: Numeric[T]) = |
|||
a.zip(b).map(c => num.times(num.minus(c._1, c._2), num.minus(c._1, c._2))).sum |
|||
def compare[T](a: Seq[T], b: Seq[T])(implicit num: Numeric[T]): Int = |
|||
a.zip(b).find(c => num.compare(c._1, c._2) != 0).map(c => num.compare(c._1, c._2)).getOrElse(0) |
|||
// Something to keep track of nodes visited as per task |
|||
case class Nearest[T](value: Seq[T], to: Seq[T], visited: Set[KDNode[T]] = Set[KDNode[T]]())(implicit num: Numeric[T]) { |
|||
lazy val distsq = KDTree.distsq(value, to) |
|||
override def toString = f"Searched for=${to} found=${value} distance=${math.sqrt(num.toDouble(distsq))}%.4f visited=${visited.size}" |
|||
} |
|||
}</lang> |
|||
Task test: |
|||
<lang Scala>object KDTreeTest extends App { |
|||
def test[T](haystack: Seq[Seq[T]], needles: Seq[T]*)(implicit num: Numeric[T]) = { |
|||
println |
|||
val tree = KDTree(haystack) |
|||
if (haystack.size < 20) tree.foreach(println) |
|||
for (kd <- tree; needle <- needles; nearest = kd nearest needle) { |
|||
println(nearest) |
|||
// Brute force proof |
|||
val better = haystack |
|||
.map(KDTree.Nearest(_, needle)) |
|||
.filter(n => num.lt(n.distsq, nearest.distsq)) |
|||
.sortBy(_.distsq) |
|||
assert(better.isEmpty, s"Found ${better.size} closer than ${nearest.value} e.g. ${better.head}") |
|||
} |
|||
} |
|||
// Results 1 |
|||
val wikitest = List(List(2,3), List(5,4), List(9,6), List(4,7), List(8,1), List(7,2)) |
|||
test(wikitest, List(9,2)) |
|||
// Results 2 (1000 points uniformly distributed in 3-d cube coordinates, sides 2 to 20) |
|||
val uniform = for(x <- 1 to 10; y <- 1 to 10; z <- 1 to 10) yield List(x*2, y*2, z*2) |
|||
assume(uniform.size == 1000) |
|||
test(uniform, List(0, 0, 0), List(2, 2, 20), List(9, 10, 11)) |
|||
// Results 3 (1000 points randomly distributed in 3-d cube coordinates, sides -1.0 to 1.0) |
|||
scala.util.Random.setSeed(0) |
|||
def random(n: Int) = (1 to n).map(_ => (scala.util.Random.nextDouble - 0.5)* 2) |
|||
test((1 to 1000).map(_ => random(3)), random(3)) |
|||
// Results 4 (27 points uniformly distributed in 3-d cube coordinates, sides 3...9) |
|||
val small = for(x <- 1 to 3; y <- 1 to 3; z <- 1 to 3) yield List(x*3, y*3, z*3) |
|||
assume(small.size == 27) |
|||
test(small, List(0, 0, 0), List(4, 5, 6)) |
|||
}</lang> |
|||
{{out}} |
|||
<pre>KDNode(List(7, 2),Some(KDNode(List(5, 4),Some(KDNode(List(2, 3),None,None,0)),Some(KDNode(List(4, 7),None,None,0)),1)),Some(KDNode(List(9, 6),Some(KDNode(List(8, 1),None,None,0)),None,1)),0) |
|||
Searched for=List(9, 2) found=List(8, 1) distance=1.4142 visited=3 |
|||
Searched for=List(0, 0, 0) found=List(2, 2, 2) distance=3.4641 visited=10 |
|||
Searched for=List(2, 2, 20) found=List(2, 2, 20) distance=0.0000 visited=9 |
|||
Searched for=List(9, 10, 11) found=List(8, 10, 12) distance=1.4142 visited=134 |
|||
Searched for=Vector(0.19269603520919643, -0.25958512078298535, -0.2572864045762784) found=Vector(0.07811099409527977, -0.2477618820196814, -0.20252227622550611) distance=0.1275 visited=25 |
|||
Searched for=List(0, 0, 0) found=List(3, 3, 3) distance=5.1962 visited=4 |
|||
Searched for=List(4, 5, 6) found=List(3, 6, 6) distance=1.4142 visited=6</pre> |
|||
=={{header|Tcl}}== |
=={{header|Tcl}}== |