diff options
-rw-r--r-- | benches/benches.rs | 8 | ||||
-rw-r--r-- | src/vp.rs | 183 |
2 files changed, 189 insertions, 2 deletions
diff --git a/benches/benches.rs b/benches/benches.rs index 8791845..a8676ba 100644 --- a/benches/benches.rs +++ b/benches/benches.rs @@ -2,7 +2,7 @@ use acap::euclid::Euclidean; use acap::exhaustive::ExhaustiveSearch; -use acap::vp::VpTree; +use acap::vp::{FlatVpTree, VpTree}; use acap::NearestNeighbors; use criterion::{black_box, criterion_group, criterion_main, Criterion}; @@ -36,6 +36,7 @@ fn bench_from_iter(c: &mut Criterion) { let mut group = c.benchmark_group("from_iter"); group.bench_function("ExhaustiveSearch", |b| b.iter(|| ExhaustiveSearch::from_iter(points.clone()))); group.bench_function("VpTree", |b| b.iter(|| VpTree::from_iter(points.clone()))); + group.bench_function("FlatVpTree", |b| b.iter(|| FlatVpTree::from_iter(points.clone()))); group.finish(); } @@ -45,25 +46,30 @@ fn bench_nearest_neighbors(c: &mut Criterion) { let exhaustive = ExhaustiveSearch::from_iter(points.clone()); let vp_tree = VpTree::from_iter(points.clone()); + let flat_vp_tree = FlatVpTree::from_iter(points.clone()); let mut nearest = c.benchmark_group("NearestNeighbors::nearest"); nearest.bench_function("ExhaustiveSearch", |b| b.iter(|| exhaustive.nearest(&target))); nearest.bench_function("VpTree", |b| b.iter(|| vp_tree.nearest(&target))); + nearest.bench_function("FlatVpTree", |b| b.iter(|| flat_vp_tree.nearest(&target))); nearest.finish(); let mut nearest_within = c.benchmark_group("NearestNeighbors::nearest_within"); nearest_within.bench_function("ExhaustiveSearch", |b| b.iter(|| exhaustive.nearest_within(&target, 0.1))); nearest_within.bench_function("VpTree", |b| b.iter(|| vp_tree.nearest_within(&target, 0.1))); + nearest_within.bench_function("FlatVpTree", |b| b.iter(|| flat_vp_tree.nearest_within(&target, 0.1))); nearest_within.finish(); let mut k_nearest = c.benchmark_group("NearestNeighbors::k_nearest"); k_nearest.bench_function("ExhaustiveSearch", |b| b.iter(|| exhaustive.k_nearest(&target, 3))); k_nearest.bench_function("VpTree", |b| b.iter(|| vp_tree.k_nearest(&target, 3))); + k_nearest.bench_function("FlatVpTree", |b| b.iter(|| flat_vp_tree.k_nearest(&target, 3))); k_nearest.finish(); let mut k_nearest_within = c.benchmark_group("NearestNeighbors::k_nearest_within"); k_nearest_within.bench_function("ExhaustiveSearch", |b| b.iter(|| exhaustive.k_nearest_within(&target, 3, 0.1))); k_nearest_within.bench_function("VpTree", |b| b.iter(|| vp_tree.k_nearest_within(&target, 3, 0.1))); + k_nearest_within.bench_function("FlatVpTree", |b| b.iter(|| flat_vp_tree.k_nearest_within(&target, 3, 0.1))); k_nearest_within.finish(); } @@ -347,6 +347,183 @@ where V: Metric, {} +/// A node in a flat VP tree. +#[derive(Debug)] +struct FlatVpNode<T, R = DistanceValue<T>> { + /// The vantage point itself. + item: T, + /// The radius of this node. + radius: R, + /// The size of the inside subtree. + inside_len: usize, +} + +impl<T: Proximity> FlatVpNode<T> { + /// Create a new FlatVpNode. + fn new(item: T) -> Self { + Self { + item, + radius: zero(), + inside_len: 0, + } + } + + /// Create a balanced tree. + fn balanced<I: IntoIterator<Item = T>>(items: I) -> Vec<Self> { + let mut nodes: Vec<_> = items + .into_iter() + .map(Self::new) + .collect(); + + Self::balance_recursive(&mut nodes); + + nodes + } + + /// Create a balanced subtree. + fn balance_recursive(nodes: &mut [Self]) { + if let Some((node, children)) = nodes.split_first_mut() { + children.sort_by_cached_key(|x| Ordered::new(node.item.distance(&x.item))); + + let (inside, outside) = children.split_at_mut(children.len() / 2); + if let Some(last) = inside.last() { + node.radius = node.item.distance(&last.item).into(); + } + + node.inside_len = inside.len(); + + Self::balance_recursive(inside); + Self::balance_recursive(outside); + } + } +} + +impl<'a, K, V, N> VpSearch<K, &'a V, N> for &'a [FlatVpNode<V>] +where + K: Proximity<&'a V, Distance = V::Distance>, + V: Proximity, + N: Neighborhood<K, &'a V>, +{ + fn item(self) -> &'a V { + &self[0].item + } + + fn radius(self) -> DistanceValue<V> { + self[0].radius + } + + fn inside(self) -> Option<Self> { + let end = self[0].inside_len + 1; + if end > 1 { + Some(&self[1..end]) + } else { + None + } + } + + fn outside(self) -> Option<Self> { + let start = self[0].inside_len + 1; + if start < self.len() { + Some(&self[start..]) + } else { + None + } + } +} + +/// A [vantage-point tree] stored as a flat array. +/// +/// A FlatVpTree is always balanced and usually more efficient than a [VpTree], but doesn't support +/// dynamic updates. +/// +/// [vantage-point tree]: https://en.wikipedia.org/wiki/Vantage-point_tree +pub struct FlatVpTree<T: Proximity> { + nodes: Vec<FlatVpNode<T>>, +} + +impl<T: Proximity> FlatVpTree<T> { + /// Create a balanced tree out of a sequence of items. + pub fn balanced<I: IntoIterator<Item = T>>(items: I) -> Self { + Self { + nodes: FlatVpNode::balanced(items), + } + } +} + +impl<T> Debug for FlatVpTree<T> +where + T: Proximity + Debug, + DistanceValue<T>: Debug, +{ + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + f.debug_struct("FlatVpTree") + .field("node", &self.nodes) + .finish() + } +} + +impl<T: Proximity> FromIterator<T> for FlatVpTree<T> { + fn from_iter<I: IntoIterator<Item = T>>(items: I) -> Self { + Self::balanced(items) + } +} + +/// An iterator that moves values out of a flat VP tree. +pub struct FlatIntoIter<T: Proximity>(std::vec::IntoIter<FlatVpNode<T>>); + +impl<T> Debug for FlatIntoIter<T> +where + T: Proximity + Debug, + DistanceValue<T>: Debug, +{ + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + f.debug_tuple("FlatIntoIter") + .field(&self.0) + .finish() + } +} + +impl<T: Proximity> Iterator for FlatIntoIter<T> { + type Item = T; + + fn next(&mut self) -> Option<T> { + self.0.next().map(|n| n.item) + } +} + +impl<T: Proximity> IntoIterator for FlatVpTree<T> { + type Item = T; + type IntoIter = FlatIntoIter<T>; + + fn into_iter(self) -> Self::IntoIter { + FlatIntoIter(self.nodes.into_iter()) + } +} + +impl<K, V> NearestNeighbors<K, V> for FlatVpTree<V> +where + K: Proximity<V, Distance = V::Distance>, + V: Proximity, +{ + fn search<'k, 'v, N>(&'v self, mut neighborhood: N) -> N + where + K: 'k, + V: 'v, + N: Neighborhood<&'k K, &'v V>, + { + if !self.nodes.is_empty() { + self.nodes.as_slice().search(&mut neighborhood); + } + neighborhood + } +} + +impl<K, V> ExactNeighbors<K, V> for FlatVpTree<V> +where + K: Metric<V, Distance = V::Distance>, + V: Metric, +{} + #[cfg(test)] mod tests { use super::*; @@ -368,5 +545,9 @@ mod tests { tree }); } -} + #[test] + fn test_flat_vp_tree() { + test_nearest_neighbors(FlatVpTree::from_iter); + } +} |