summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--benches/benches.rs8
-rw-r--r--src/vp.rs183
2 files changed, 189 insertions, 2 deletions
diff --git a/benches/benches.rs b/benches/benches.rs
index 8791845..a8676ba 100644
--- a/benches/benches.rs
+++ b/benches/benches.rs
@@ -2,7 +2,7 @@
use acap::euclid::Euclidean;
use acap::exhaustive::ExhaustiveSearch;
-use acap::vp::VpTree;
+use acap::vp::{FlatVpTree, VpTree};
use acap::NearestNeighbors;
use criterion::{black_box, criterion_group, criterion_main, Criterion};
@@ -36,6 +36,7 @@ fn bench_from_iter(c: &mut Criterion) {
let mut group = c.benchmark_group("from_iter");
group.bench_function("ExhaustiveSearch", |b| b.iter(|| ExhaustiveSearch::from_iter(points.clone())));
group.bench_function("VpTree", |b| b.iter(|| VpTree::from_iter(points.clone())));
+ group.bench_function("FlatVpTree", |b| b.iter(|| FlatVpTree::from_iter(points.clone())));
group.finish();
}
@@ -45,25 +46,30 @@ fn bench_nearest_neighbors(c: &mut Criterion) {
let exhaustive = ExhaustiveSearch::from_iter(points.clone());
let vp_tree = VpTree::from_iter(points.clone());
+ let flat_vp_tree = FlatVpTree::from_iter(points.clone());
let mut nearest = c.benchmark_group("NearestNeighbors::nearest");
nearest.bench_function("ExhaustiveSearch", |b| b.iter(|| exhaustive.nearest(&target)));
nearest.bench_function("VpTree", |b| b.iter(|| vp_tree.nearest(&target)));
+ nearest.bench_function("FlatVpTree", |b| b.iter(|| flat_vp_tree.nearest(&target)));
nearest.finish();
let mut nearest_within = c.benchmark_group("NearestNeighbors::nearest_within");
nearest_within.bench_function("ExhaustiveSearch", |b| b.iter(|| exhaustive.nearest_within(&target, 0.1)));
nearest_within.bench_function("VpTree", |b| b.iter(|| vp_tree.nearest_within(&target, 0.1)));
+ nearest_within.bench_function("FlatVpTree", |b| b.iter(|| flat_vp_tree.nearest_within(&target, 0.1)));
nearest_within.finish();
let mut k_nearest = c.benchmark_group("NearestNeighbors::k_nearest");
k_nearest.bench_function("ExhaustiveSearch", |b| b.iter(|| exhaustive.k_nearest(&target, 3)));
k_nearest.bench_function("VpTree", |b| b.iter(|| vp_tree.k_nearest(&target, 3)));
+ k_nearest.bench_function("FlatVpTree", |b| b.iter(|| flat_vp_tree.k_nearest(&target, 3)));
k_nearest.finish();
let mut k_nearest_within = c.benchmark_group("NearestNeighbors::k_nearest_within");
k_nearest_within.bench_function("ExhaustiveSearch", |b| b.iter(|| exhaustive.k_nearest_within(&target, 3, 0.1)));
k_nearest_within.bench_function("VpTree", |b| b.iter(|| vp_tree.k_nearest_within(&target, 3, 0.1)));
+ k_nearest_within.bench_function("FlatVpTree", |b| b.iter(|| flat_vp_tree.k_nearest_within(&target, 3, 0.1)));
k_nearest_within.finish();
}
diff --git a/src/vp.rs b/src/vp.rs
index 120e13b..e0645de 100644
--- a/src/vp.rs
+++ b/src/vp.rs
@@ -347,6 +347,183 @@ where
V: Metric,
{}
+/// A node in a flat VP tree.
+#[derive(Debug)]
+struct FlatVpNode<T, R = DistanceValue<T>> {
+ /// The vantage point itself.
+ item: T,
+ /// The radius of this node.
+ radius: R,
+ /// The size of the inside subtree.
+ inside_len: usize,
+}
+
+impl<T: Proximity> FlatVpNode<T> {
+ /// Create a new FlatVpNode.
+ fn new(item: T) -> Self {
+ Self {
+ item,
+ radius: zero(),
+ inside_len: 0,
+ }
+ }
+
+ /// Create a balanced tree.
+ fn balanced<I: IntoIterator<Item = T>>(items: I) -> Vec<Self> {
+ let mut nodes: Vec<_> = items
+ .into_iter()
+ .map(Self::new)
+ .collect();
+
+ Self::balance_recursive(&mut nodes);
+
+ nodes
+ }
+
+ /// Create a balanced subtree.
+ fn balance_recursive(nodes: &mut [Self]) {
+ if let Some((node, children)) = nodes.split_first_mut() {
+ children.sort_by_cached_key(|x| Ordered::new(node.item.distance(&x.item)));
+
+ let (inside, outside) = children.split_at_mut(children.len() / 2);
+ if let Some(last) = inside.last() {
+ node.radius = node.item.distance(&last.item).into();
+ }
+
+ node.inside_len = inside.len();
+
+ Self::balance_recursive(inside);
+ Self::balance_recursive(outside);
+ }
+ }
+}
+
+impl<'a, K, V, N> VpSearch<K, &'a V, N> for &'a [FlatVpNode<V>]
+where
+ K: Proximity<&'a V, Distance = V::Distance>,
+ V: Proximity,
+ N: Neighborhood<K, &'a V>,
+{
+ fn item(self) -> &'a V {
+ &self[0].item
+ }
+
+ fn radius(self) -> DistanceValue<V> {
+ self[0].radius
+ }
+
+ fn inside(self) -> Option<Self> {
+ let end = self[0].inside_len + 1;
+ if end > 1 {
+ Some(&self[1..end])
+ } else {
+ None
+ }
+ }
+
+ fn outside(self) -> Option<Self> {
+ let start = self[0].inside_len + 1;
+ if start < self.len() {
+ Some(&self[start..])
+ } else {
+ None
+ }
+ }
+}
+
+/// A [vantage-point tree] stored as a flat array.
+///
+/// A FlatVpTree is always balanced and usually more efficient than a [VpTree], but doesn't support
+/// dynamic updates.
+///
+/// [vantage-point tree]: https://en.wikipedia.org/wiki/Vantage-point_tree
+pub struct FlatVpTree<T: Proximity> {
+ nodes: Vec<FlatVpNode<T>>,
+}
+
+impl<T: Proximity> FlatVpTree<T> {
+ /// Create a balanced tree out of a sequence of items.
+ pub fn balanced<I: IntoIterator<Item = T>>(items: I) -> Self {
+ Self {
+ nodes: FlatVpNode::balanced(items),
+ }
+ }
+}
+
+impl<T> Debug for FlatVpTree<T>
+where
+ T: Proximity + Debug,
+ DistanceValue<T>: Debug,
+{
+ fn fmt(&self, f: &mut Formatter) -> fmt::Result {
+ f.debug_struct("FlatVpTree")
+ .field("node", &self.nodes)
+ .finish()
+ }
+}
+
+impl<T: Proximity> FromIterator<T> for FlatVpTree<T> {
+ fn from_iter<I: IntoIterator<Item = T>>(items: I) -> Self {
+ Self::balanced(items)
+ }
+}
+
+/// An iterator that moves values out of a flat VP tree.
+pub struct FlatIntoIter<T: Proximity>(std::vec::IntoIter<FlatVpNode<T>>);
+
+impl<T> Debug for FlatIntoIter<T>
+where
+ T: Proximity + Debug,
+ DistanceValue<T>: Debug,
+{
+ fn fmt(&self, f: &mut Formatter) -> fmt::Result {
+ f.debug_tuple("FlatIntoIter")
+ .field(&self.0)
+ .finish()
+ }
+}
+
+impl<T: Proximity> Iterator for FlatIntoIter<T> {
+ type Item = T;
+
+ fn next(&mut self) -> Option<T> {
+ self.0.next().map(|n| n.item)
+ }
+}
+
+impl<T: Proximity> IntoIterator for FlatVpTree<T> {
+ type Item = T;
+ type IntoIter = FlatIntoIter<T>;
+
+ fn into_iter(self) -> Self::IntoIter {
+ FlatIntoIter(self.nodes.into_iter())
+ }
+}
+
+impl<K, V> NearestNeighbors<K, V> for FlatVpTree<V>
+where
+ K: Proximity<V, Distance = V::Distance>,
+ V: Proximity,
+{
+ fn search<'k, 'v, N>(&'v self, mut neighborhood: N) -> N
+ where
+ K: 'k,
+ V: 'v,
+ N: Neighborhood<&'k K, &'v V>,
+ {
+ if !self.nodes.is_empty() {
+ self.nodes.as_slice().search(&mut neighborhood);
+ }
+ neighborhood
+ }
+}
+
+impl<K, V> ExactNeighbors<K, V> for FlatVpTree<V>
+where
+ K: Metric<V, Distance = V::Distance>,
+ V: Metric,
+{}
+
#[cfg(test)]
mod tests {
use super::*;
@@ -368,5 +545,9 @@ mod tests {
tree
});
}
-}
+ #[test]
+ fn test_flat_vp_tree() {
+ test_nearest_neighbors(FlatVpTree::from_iter);
+ }
+}