diff options
author | Tavian Barnes <tavianator@tavianator.com> | 2020-05-03 10:55:16 -0400 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-05-03 10:55:16 -0400 |
commit | ce2904b4840611f769b92b55bf6d9b5afe84d3d7 (patch) | |
tree | a133319a302f95edf7a7a261262a8f24473bd21c /src/metric/kd.rs | |
parent | d95e93bf70f3351e6fd489284794ef7909fd94ce (diff) | |
parent | 2984e8f93fe88d0ee7eb3c0561dcd2da44807429 (diff) | |
download | kd-forest-ce2904b4840611f769b92b55bf6d9b5afe84d3d7.tar.xz |
Merge pull request #1 from tavianator/rust
Rewrite in rust
Diffstat (limited to 'src/metric/kd.rs')
-rw-r--r-- | src/metric/kd.rs | 226 |
1 files changed, 226 insertions, 0 deletions
diff --git a/src/metric/kd.rs b/src/metric/kd.rs new file mode 100644 index 0000000..2caf4a3 --- /dev/null +++ b/src/metric/kd.rs @@ -0,0 +1,226 @@ +//! [k-d trees](https://en.wikipedia.org/wiki/K-d_tree). + +use super::{Metric, NearestNeighbors, Neighborhood}; + +use ordered_float::OrderedFloat; + +use std::iter::FromIterator; + +/// A point in Cartesian space. +pub trait Cartesian: Metric<[f64]> { + /// Returns the number of dimensions necessary to describe this point. + fn dimensions(&self) -> usize; + + /// Returns the value of the `i`th coordinate of this point (`i < self.dimensions()`). + fn coordinate(&self, i: usize) -> f64; +} + +/// Blanket [Cartesian] implementation for references. +impl<'a, T: Cartesian> Cartesian for &'a T { + fn dimensions(&self) -> usize { + (*self).dimensions() + } + + fn coordinate(&self, i: usize) -> f64 { + (*self).coordinate(i) + } +} + +/// Blanket [Metric<[f64]>](Metric) implementation for [Cartesian] references. +impl<'a, T: Cartesian> Metric<[f64]> for &'a T { + type Distance = T::Distance; + + fn distance(&self, other: &[f64]) -> Self::Distance { + (*self).distance(other) + } +} + +/// Standard cartesian space. +impl Cartesian for [f64] { + fn dimensions(&self) -> usize { + self.len() + } + + fn coordinate(&self, i: usize) -> f64 { + self[i] + } +} + +/// Marker trait for cartesian metric spaces. +pub trait CartesianMetric<T: ?Sized = Self>: + Cartesian + Metric<T, Distance = <Self as Metric<[f64]>>::Distance> +{ +} + +/// Blanket [CartesianMetric] implementation for cartesian spaces with compatible metric distance +/// types. +impl<T, U> CartesianMetric<T> for U +where + T: ?Sized, + U: ?Sized + Cartesian + Metric<T, Distance = <U as Metric<[f64]>>::Distance>, +{ +} + +/// A node in a k-d tree. +#[derive(Debug)] +struct KdNode<T> { + /// The value stored in this node. + item: T, + /// The size of the left subtree. + left_len: usize, +} + +impl<T: Cartesian> KdNode<T> { + /// Create a new KdNode. + fn new(item: T) -> Self { + Self { item, left_len: 0 } + } + + /// Build a k-d tree recursively. + fn build(slice: &mut [KdNode<T>], i: usize) { + if slice.is_empty() { + return; + } + + slice.sort_unstable_by_key(|n| OrderedFloat::from(n.item.coordinate(i))); + + let mid = slice.len() / 2; + slice.swap(0, mid); + + let (node, children) = slice.split_first_mut().unwrap(); + let (left, right) = children.split_at_mut(mid); + node.left_len = left.len(); + + let j = (i + 1) % node.item.dimensions(); + Self::build(left, j); + Self::build(right, j); + } + + /// Recursively search for nearest neighbors. + fn recurse<'a, U, N>( + slice: &'a [KdNode<T>], + i: usize, + closest: &mut [f64], + neighborhood: &mut N, + ) where + T: 'a, + U: CartesianMetric<&'a T>, + N: Neighborhood<&'a T, U>, + { + let (node, children) = slice.split_first().unwrap(); + neighborhood.consider(&node.item); + + let target = neighborhood.target(); + let ti = target.coordinate(i); + let ni = node.item.coordinate(i); + let j = (i + 1) % node.item.dimensions(); + + let (left, right) = children.split_at(node.left_len); + let (near, far) = if ti <= ni { + (left, right) + } else { + (right, left) + }; + + if !near.is_empty() { + Self::recurse(near, j, closest, neighborhood); + } + + if !far.is_empty() { + let saved = closest[i]; + closest[i] = ni; + if neighborhood.contains_distance(target.distance(closest)) { + Self::recurse(far, j, closest, neighborhood); + } + closest[i] = saved; + } + } +} + +/// A [k-d tree](https://en.wikipedia.org/wiki/K-d_tree). +#[derive(Debug)] +pub struct KdTree<T>(Vec<KdNode<T>>); + +impl<T: Cartesian> FromIterator<T> for KdTree<T> { + /// Create a new k-d tree from a set of points. + fn from_iter<I: IntoIterator<Item = T>>(items: I) -> Self { + let mut nodes: Vec<_> = items.into_iter().map(KdNode::new).collect(); + KdNode::build(nodes.as_mut_slice(), 0); + Self(nodes) + } +} + +impl<T, U> NearestNeighbors<T, U> for KdTree<T> +where + T: Cartesian, + U: CartesianMetric<T>, +{ + fn search<'a, 'b, N>(&'a self, mut neighborhood: N) -> N + where + T: 'a, + U: 'b, + N: Neighborhood<&'a T, &'b U>, + { + if !self.0.is_empty() { + let target = neighborhood.target(); + let dims = target.dimensions(); + let mut closest: Vec<_> = (0..dims).map(|i| target.coordinate(i)).collect(); + + KdNode::recurse(&self.0, 0, &mut closest, &mut neighborhood); + } + + neighborhood + } +} + +/// An iterator that the moves values out of a k-d tree. +#[derive(Debug)] +pub struct IntoIter<T>(std::vec::IntoIter<KdNode<T>>); + +impl<T> Iterator for IntoIter<T> { + type Item = T; + + fn next(&mut self) -> Option<T> { + self.0.next().map(|n| n.item) + } +} + +impl<T> IntoIterator for KdTree<T> { + type Item = T; + type IntoIter = IntoIter<T>; + + fn into_iter(self) -> Self::IntoIter { + IntoIter(self.0.into_iter()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::metric::tests::{test_nearest_neighbors, Point}; + use crate::metric::SquaredDistance; + + impl Metric<[f64]> for Point { + type Distance = SquaredDistance; + + fn distance(&self, other: &[f64]) -> Self::Distance { + self.0.distance(other) + } + } + + impl Cartesian for Point { + fn dimensions(&self) -> usize { + self.0.dimensions() + } + + fn coordinate(&self, i: usize) -> f64 { + self.0.coordinate(i) + } + } + + #[test] + fn test_kd_tree() { + test_nearest_neighbors(KdTree::from_iter); + } +} |