diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/lib.rs | 1 | ||||
-rw-r--r-- | src/vp.rs | 372 |
2 files changed, 373 insertions, 0 deletions
@@ -6,6 +6,7 @@ pub mod coords; pub mod distance; pub mod euclid; pub mod exhaustive; +pub mod vp; mod util; diff --git a/src/vp.rs b/src/vp.rs new file mode 100644 index 0000000..120e13b --- /dev/null +++ b/src/vp.rs @@ -0,0 +1,372 @@ +//! Vantage-point trees. + +use crate::distance::{DistanceValue, Metric, Proximity}; +use crate::util::Ordered; +use crate::{ExactNeighbors, NearestNeighbors, Neighborhood}; + +use num_traits::zero; + +use std::fmt::{self, Debug, Formatter}; +use std::iter::{Extend, FromIterator}; +use std::ops::Deref; + +/// A node in a VP tree. +#[derive(Debug)] +struct VpNode<T, R = DistanceValue<T>> { + /// The vantage point itself. + item: T, + /// The radius of this node. + radius: R, + /// The subtree inside the radius, if any. + inside: Option<Box<Self>>, + /// The subtree outside the radius, if any. + outside: Option<Box<Self>>, +} + +impl<T: Proximity> VpNode<T> { + /// Create a new VpNode. + fn new(item: T) -> Self { + Self { + item, + radius: zero(), + inside: None, + outside: None, + } + } + + /// Create a balanced tree. + fn balanced<I: IntoIterator<Item = T>>(items: I) -> Option<Self> { + let mut nodes: Vec<_> = items + .into_iter() + .map(Self::new) + .map(Box::new) + .map(Some) + .collect(); + + Self::balanced_recursive(&mut nodes) + .map(|node| *node) + } + + /// Create a balanced subtree. + fn balanced_recursive(nodes: &mut [Option<Box<Self>>]) -> Option<Box<Self>> { + if let Some((node, children)) = nodes.split_first_mut() { + let mut node = node.take().unwrap(); + children.sort_by_cached_key(|x| Ordered::new(Self::box_distance(&node, x))); + + let (inside, outside) = children.split_at_mut(children.len() / 2); + if let Some(last) = inside.last() { + node.radius = Self::box_distance(&node, last).into(); + } + + node.inside = Self::balanced_recursive(inside); + node.outside = Self::balanced_recursive(outside); + + Some(node) + } else { + None + } + } + + /// Get the distance between to boxed nodes. + fn box_distance(node: &Box<Self>, child: &Option<Box<Self>>) -> T::Distance { + node.item.distance(&child.as_ref().unwrap().item) + } + + /// Push a new item into this subtree. + fn push(&mut self, item: T) + where + // https://github.com/rust-lang/rust/issues/72582 + T::Distance: PartialOrd + PartialOrd<DistanceValue<T>>, + { + match (&mut self.inside, &mut self.outside) { + (None, None) => { + self.outside = Some(Box::new(Self::new(item))); + } + (Some(inside), Some(outside)) => { + if self.item.distance(&item) <= self.radius { + inside.push(item); + } else { + outside.push(item); + } + } + _ => { + let node = Box::new(Self::new(item)); + let other = self.inside.take().xor(self.outside.take()).unwrap(); + + let r1 = self.item.distance(&node.item); + let r2 = self.item.distance(&other.item); + + if r1 <= r2 { + self.radius = r2.into(); + self.inside = Some(node); + self.outside = Some(other); + } else { + self.radius = r1.into(); + self.inside = Some(other); + self.outside = Some(node); + } + } + } + } +} + +trait VpSearch<K, V, N>: Copy +where + K: Proximity<V, Distance = V::Distance>, + V: Proximity, + N: Neighborhood<K, V>, +{ + /// Get the vantage point of this node. + fn item(self) -> V; + + /// Get the radius of this node. + fn radius(self) -> DistanceValue<V>; + + /// Get the inside subtree. + fn inside(self) -> Option<Self>; + + /// Get the outside subtree. + fn outside(self) -> Option<Self>; + + /// Recursively search for nearest neighbors. + fn search(self, neighborhood: &mut N) { + let distance = neighborhood.consider(self.item()).into(); + + if distance <= self.radius() { + self.search_inside(distance, neighborhood); + self.search_outside(distance, neighborhood); + } else { + self.search_outside(distance, neighborhood); + self.search_inside(distance, neighborhood); + } + } + + /// Search the inside subtree. + fn search_inside(self, distance: DistanceValue<V>, neighborhood: &mut N) { + if let Some(inside) = self.inside() { + if neighborhood.contains(distance - self.radius()) { + inside.search(neighborhood); + } + } + } + + /// Search the outside subtree. + fn search_outside(self, distance: DistanceValue<V>, neighborhood: &mut N) { + if let Some(outside) = self.outside() { + if neighborhood.contains(self.radius() - distance) { + outside.search(neighborhood); + } + } + } +} + +impl<'a, K, V, N> VpSearch<K, &'a V, N> for &'a VpNode<V> +where + K: Proximity<&'a V, Distance = V::Distance>, + V: Proximity, + N: Neighborhood<K, &'a V>, +{ + fn item(self) -> &'a V { + &self.item + } + + fn radius(self) -> DistanceValue<V> { + self.radius + } + + fn inside(self) -> Option<Self> { + self.inside.as_ref().map(Box::deref) + } + + fn outside(self) -> Option<Self> { + self.outside.as_ref().map(Box::deref) + } +} + +/// A [vantage-point tree](https://en.wikipedia.org/wiki/Vantage-point_tree). +pub struct VpTree<T: Proximity> { + root: Option<VpNode<T>>, +} + +impl<T: Proximity> VpTree<T> { + /// Create an empty tree. + pub fn new() -> Self { + Self { + root: None, + } + } + + /// Create a balanced tree out of a sequence of items. + pub fn balanced<I: IntoIterator<Item = T>>(items: I) -> Self { + Self { + root: VpNode::balanced(items), + } + } + + /// Rebalance this VP tree. + pub fn balance(&mut self) { + let mut nodes = Vec::new(); + if let Some(root) = self.root.take() { + nodes.push(Some(Box::new(root))); + } + + let mut i = 0; + while i < nodes.len() { + let node = nodes[i].as_mut().unwrap(); + let inside = node.inside.take(); + let outside = node.outside.take(); + if inside.is_some() { + nodes.push(inside); + } + if outside.is_some() { + nodes.push(outside); + } + + i += 1; + } + + self.root = VpNode::balanced_recursive(&mut nodes) + .map(|node| *node); + } + + /// Push a new item into the tree. + /// + /// Inserting elements individually tends to unbalance the tree. Use [VpTree::balanced] if + /// possible to create a balanced tree from a batch of items. + pub fn push(&mut self, item: T) { + if let Some(root) = &mut self.root { + root.push(item); + } else { + self.root = Some(VpNode::new(item)); + } + } +} + +// Can't derive(Debug) due to https://github.com/rust-lang/rust/issues/26925 +impl<T> Debug for VpTree<T> +where + T: Proximity + Debug, + DistanceValue<T>: Debug, +{ + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + f.debug_struct("VpTree") + .field("root", &self.root) + .finish() + } +} + +impl<T: Proximity> Extend<T> for VpTree<T> { + fn extend<I: IntoIterator<Item = T>>(&mut self, items: I) { + if self.root.is_some() { + for item in items { + self.push(item); + } + } else { + self.root = VpNode::balanced(items); + } + } +} + +impl<T: Proximity> FromIterator<T> for VpTree<T> { + fn from_iter<I: IntoIterator<Item = T>>(items: I) -> Self { + Self::balanced(items) + } +} + +/// An iterator that moves values out of a VP tree. +pub struct IntoIter<T: Proximity> { + stack: Vec<VpNode<T>>, +} + +impl<T: Proximity> IntoIter<T> { + fn new(node: Option<VpNode<T>>) -> Self { + Self { + stack: node.into_iter().collect(), + } + } +} + +impl<T> Debug for IntoIter<T> +where + T: Proximity + Debug, + DistanceValue<T>: Debug, +{ + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + f.debug_struct("IntoIter") + .field("stack", &self.stack) + .finish() + } +} + +impl<T: Proximity> Iterator for IntoIter<T> { + type Item = T; + + fn next(&mut self) -> Option<T> { + self.stack.pop().map(|node| { + if let Some(inside) = node.inside { + self.stack.push(*inside); + } + if let Some(outside) = node.outside { + self.stack.push(*outside); + } + node.item + }) + } +} + +impl<T: Proximity> IntoIterator for VpTree<T> { + type Item = T; + type IntoIter = IntoIter<T>; + + fn into_iter(self) -> Self::IntoIter { + IntoIter::new(self.root) + } +} + +impl<K, V> NearestNeighbors<K, V> for VpTree<V> +where + K: Proximity<V, Distance = V::Distance>, + V: Proximity, +{ + fn search<'k, 'v, N>(&'v self, mut neighborhood: N) -> N + where + K: 'k, + V: 'v, + N: Neighborhood<&'k K, &'v V>, + { + if let Some(root) = &self.root { + root.search(&mut neighborhood); + } + neighborhood + } +} + +impl<K, V> ExactNeighbors<K, V> for VpTree<V> +where + K: Metric<V, Distance = V::Distance>, + V: Metric, +{} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::tests::test_nearest_neighbors; + + #[test] + fn test_vp_tree() { + test_nearest_neighbors(VpTree::from_iter); + } + + #[test] + fn test_unbalanced_vp_tree() { + test_nearest_neighbors(|points| { + let mut tree = VpTree::new(); + for point in points { + tree.push(point); + } + tree + }); + } +} + |