K-d tree: Difference between revisions

54 bytes added ,  10 years ago
Updated to work in Python2 and Python3, use __slots__ for Orthotope, use pep8 style
(Updated D entry)
(Updated to work in Python2 and Python3, use __slots__ for Orthotope, use pep8 style)
Line 972:
from math import sqrt
from copy import deepcopy
 
 
def sqd(p1, p2):
return sum((c1 - c2) ** 2 for c1, c2 in zip(p1, p2))
 
 
class Kd_nodeKdNode(object):
__slots__ = [("dom_elt", "split", "left", "right"])
 
class Kd_node(object):
__slots__ = ["dom_elt", "split", "left", "right"]
def __init__(self, dom_elt_, split_, left_, right_):
self.dom_elt = dom_elt_
Line 983 ⟶ 986:
self.left = left_
self.right = right_
 
 
class Orthotope(object):
__slots__ = ("min", "max")
 
def __init__(self, mi, ma):
self.min, self.max = mi, ma
 
 
class Kd_treeKdTree(object):
def __init__(self, pts, bounds_):
def nk2(split, exset):
Line 996 ⟶ 1,003:
m = len(exset) // 2
d = exset[m]
while m + 1 < len(exset) and exset[m+1][split] == d[split]:
m += 1
 
s2 = (split + 1) % len(d) # cycle coordinates
return Kd_nodeKdNode(d, split, nk2(s2, exset[:m]),
nk2(s2, exset[m + 1:]))
self.n = nk2(0, pts)
self.bounds = bounds_
 
T3 = namedtuple("T3", "nearest dist_sqd nodes_visited")
 
 
def find_nearest(k, t, p):
def nn(kd, target, hr, max_dist_sqd):
if kd ==is None:
return T3([0.0] * k, float("inf"), 0)
 
Line 1,052 ⟶ 1,060:
 
return nn(t.n, p, t.bounds, float("inf"))
 
 
def show_nearest(k, heading, kd, p):
print (heading + ":")
print ("Point: ", p)
n = find_nearest(k, kd, p)
print ("Nearest neighbor:", n.nearest)
print ("Distance: ", sqrt(n.dist_sqd))
print ("Nodes visited: ", n.nodes_visited, "\n")
 
 
def random_point(k):
return [random() for _ in xrangerange(k)]
 
 
def random_points(k, n):
return [random_point(k) for _ in xrangerange(n)]
 
if __name__ == "__main__":
seed(1)
P = lambda *coords: list(coords)
kd1 = Kd_treeKdTree([P(2, 3), P(5, 4), P(9, 6), P(4, 7), P(8, 1), P(7, 2)],
Orthotope(P(0, 0), P(10, 10)))
show_nearest(2, "Wikipedia example data", kd1, P(9, 2))
Line 1,076 ⟶ 1,087:
N = 400000
t0 = clock()
kd2 = Kd_treeKdTree(random_points(3, N), Orthotope(P(0, 0, 0), P(1, 1, 1)))
t1 = clock()
text = lambda *parts: "".join(map(str, parts))