K-d tree: Difference between revisions

no edit summary
No edit summary
Line 221:
>> Million tree
visited 4271442 nodes for 100000 random findings (42.714420 per lookup)
</pre>
 
=={{header|C++}}==
This code is based on the C version. A significant difference is the use of the standard library function std::nth_element (from <algorithm>), which replaces the find_median function in the C version.
 
<lang cpp>
#include <algorithm>
#include <array>
#include <cmath>
#include <cstddef>
#include <initializer_list>
#include <iostream>
#include <iterator>
#include <random>
#include <stdexcept>
#include <utility>
#include <vector>
 
/**
 * Class for representing a point with a fixed number of coordinates.
 * coordinate_type must be a numeric type.
 */
template<typename coordinate_type, size_t dimensions>
class point
{
public:
    /**
     * Constructs a point from an array holding one coordinate per dimension.
     *
     * @param c coordinates of the point
     */
    point(std::array<coordinate_type, dimensions> c) : coords_(c)
    {
    }
    /**
     * Constructs a point from a list of coordinates. Values beyond
     * the number of dimensions are ignored; if the list is shorter than
     * the number of dimensions, the remaining coordinates are zero.
     *
     * @param list coordinates of the point
     */
    point(std::initializer_list<coordinate_type> list) : coords_{}
    {
        // coords_ is value-initialized above so that a short list never
        // leaves indeterminate coordinates behind.
        size_t n = std::min(dimensions, list.size());
        std::copy_n(list.begin(), n, coords_.begin());
    }
    /**
     * Returns the coordinate in the given dimension.
     *
     * @param index dimension index (zero based); not range checked
     * @return coordinate in the given dimension
     */
    coordinate_type get(size_t index) const
    {
        return coords_[index];
    }
    /**
     * Returns the distance squared from this point to another
     * point. Note that, despite the name, no square root is taken.
     *
     * @param pt another point
     * @return distance squared from this point to the other point
     */
    double distance(const point& pt) const
    {
        double dist = 0;
        for (size_t i = 0; i < dimensions; ++i)
        {
            // Convert before subtracting: subtracting in coordinate_type
            // wraps around for unsigned types and can overflow for
            // signed integers.
            double d = static_cast<double>(get(i)) - static_cast<double>(pt.get(i));
            dist += d * d;
        }
        return dist;
    }
private:
    std::array<coordinate_type, dimensions> coords_;
};
 
/**
 * Writes a point to a stream as a parenthesised, comma separated
 * list of its coordinates, for example "(1, 2, 3)".
 *
 * @param out stream to write to
 * @param pt point to write
 * @return the stream that was passed in
 */
template<typename coordinate_type, size_t dimensions>
std::ostream& operator<<(std::ostream& out, const point<coordinate_type, dimensions>& pt)
{
    out << '(';
    if (dimensions > 0)
    {
        // The first coordinate has no separator in front of it; every
        // later one is preceded by ", ".
        out << pt.get(0);
        for (size_t axis = 1; axis < dimensions; ++axis)
        {
            out << ", ";
            out << pt.get(axis);
        }
    }
    out << ')';
    return out;
}
 
/**
* C++ k-d tree implementation, based on the C version at rosettacode.org.
*/
template<typename coordinate_type, size_t dimensions>
class kdtree
{
public:
typedef point<coordinate_type, dimensions> point_type;
private:
struct node
{
node(const point_type& pt) : point_(pt), left_(nullptr), right_(nullptr)
{
}
coordinate_type get(size_t index) const
{
return point_.get(index);
}
double distance(const point_type& pt) const
{
return point_.distance(pt);
}
point_type point_;
node* left_;
node* right_;
};
node* root_;
node* best_;
double best_dist_;
size_t visited_;
std::vector<node> nodes_;
 
struct node_cmp
{
node_cmp(size_t index) : index_(index)
{
}
bool operator()(const node& n1, const node& n2) const
{
return n1.point_.get(index_) < n2.point_.get(index_);
}
size_t index_;
};
 
node* make_tree(size_t begin, size_t end, size_t index)
{
if (end <= begin)
return nullptr;
size_t n = begin + (end - begin)/2;
std::nth_element(&nodes_[begin], &nodes_[n], &nodes_[end], node_cmp(index));
index = (index + 1) % dimensions;
nodes_[n].left_ = make_tree(begin, n, index);
nodes_[n].right_ = make_tree(n + 1, end, index);
return &nodes_[n];
}
 
void nearest(node* root, const point_type& point, size_t index)
{
if (root == nullptr)
return;
++visited_;
double d = root->distance(point);
if (best_ == nullptr || d < best_dist_)
{
best_dist_ = d;
best_ = root;
}
if (best_dist_ == 0)
return;
double dx = root->get(index) - point.get(index);
index = (index + 1) % dimensions;
nearest(dx > 0 ? root->left_ : root->right_, point, index);
if (dx * dx >= best_dist_)
return;
nearest(dx > 0 ? root->right_ : root->left_, point, index);
}
public:
/**
* Constructor taking a pair of iterators. Adds each
* point in the range [begin, end) to the tree.
*
* @param begin start of range
* @param end end of range
*/
template<typename iterator>
kdtree(iterator begin, iterator end)
{
nodes_.reserve(std::distance(begin, end));
for (auto i = begin; i != end; ++i)
nodes_.emplace_back(*i);
root_ = make_tree(0, nodes_.size(), 0);
}
/**
* Constructor taking a function object that generates
* points. The function object will be called n times
* to populate the tree.
*
* @param f function that returns a point
* @param n number of points to add
*/
template<typename func>
kdtree(func&& f, size_t n)
{
nodes_.reserve(n);
for (size_t i = 0; i < n; ++i)
nodes_.emplace_back(f());
root_ = make_tree(0, nodes_.size(), 0);
}
 
/**
* Returns true if the tree is empty, false otherwise.
*/
bool empty() const
{
return nodes_.empty();
}
 
/**
* Returns the number of nodes visited by the last call
* to nearest().
*/
size_t visited() const
{
return visited_;
}
 
/**
* Returns the distance between the input point and return value
* from the last call to nearest().
*/
double distance() const
{
return std::sqrt(best_dist_);
}
 
/**
* Finds the nearest point in the tree to the given point.
* It is not valid to call this function if the tree is empty.
*
* @param pt a point
* @param the nearest point in the tree to the given point
*/
const point_type& nearest(const point_type& pt)
{
if (root_ == nullptr)
throw std::logic_error("tree is empty");
best_ = nullptr;
visited_ = 0;
best_dist_ = 0;
nearest(root_, pt, 0);
return best_->point_;
}
};
 
void test_wikipedia()
{
typedef point<int, 2> point2d;
typedef kdtree<int, 2> tree2d;
 
point2d points[] = { { 2, 3 }, { 5, 4 }, { 9, 6 }, { 4, 7 }, { 8, 1 }, { 7, 2 } };
 
tree2d tree(std::begin(points), std::end(points));
point2d n = tree.nearest({ 9, 2 });
 
std::cout << "Wikipedia example data:\n";
std::cout << "nearest point: " << n << '\n';
std::cout << "distance: " << tree.distance() << '\n';
std::cout << "nodes visited: " << tree.visited() << '\n';
}
 
typedef point<double, 3> point3d;
typedef kdtree<double, 3> tree3d;
 
struct random_point_generator
{
random_point_generator(double min, double max)
: engine_(std::random_device()()), distribution_(min, max)
{
}
 
point3d operator()()
{
double x = distribution_(engine_);
double y = distribution_(engine_);
double z = distribution_(engine_);
return point3d({x, y, z});
}
 
std::mt19937 engine_;
std::uniform_real_distribution<double> distribution_;
};
 
void test_random(size_t count)
{
random_point_generator rpg(0, 1);
tree3d tree(rpg, count);
point3d pt(rpg());
point3d n = tree.nearest(pt);
 
std::cout << "Random data (" << count << " points):\n";
std::cout << "point: " << pt << '\n';
std::cout << "nearest point: " << n << '\n';
std::cout << "distance: " << tree.distance() << '\n';
std::cout << "nodes visited: " << tree.visited() << '\n';
}
 
/**
 * Runs the Wikipedia example followed by two random tests (1000 and
 * 1000000 points), with a blank line before each random test. Any
 * standard exception is reported on standard error.
 */
int main()
{
    try
    {
        test_wikipedia();
        const size_t counts[] = { 1000, 1000000 };
        for (size_t count : counts)
        {
            std::cout << '\n';
            test_random(count);
        }
    }
    catch (const std::exception& e)
    {
        std::cerr << e.what() << '\n';
    }

    return 0;
}
</lang>
 
{{out}}
<pre>
Wikipedia example data:
nearest point: (8, 1)
distance: 1.41421
nodes visited: 3
 
Random data (1000 points):
point: (0.740311, 0.290258, 0.832057)
nearest point: (0.761247, 0.294663, 0.83404)
distance: 0.0214867
nodes visited: 15
 
Random data (1000000 points):
point: (0.646712, 0.555327, 0.596551)
nearest point: (0.642795, 0.552513, 0.599618)
distance: 0.00571496
nodes visited: 46
</pre>
 
1,777

edits