diff options
-rw-r--r-- | benches/benches.rs | 8 | ||||
-rw-r--r-- | src/kd.rs | 155 |
2 files changed, 162 insertions, 1 deletions
diff --git a/benches/benches.rs b/benches/benches.rs index 8464179..3e67c2a 100644 --- a/benches/benches.rs +++ b/benches/benches.rs @@ -2,7 +2,7 @@ use acap::euclid::Euclidean; use acap::exhaustive::ExhaustiveSearch; -use acap::kd::KdTree; +use acap::kd::{FlatKdTree, KdTree}; use acap::vp::{FlatVpTree, VpTree}; use acap::NearestNeighbors; @@ -39,6 +39,7 @@ fn bench_from_iter(c: &mut Criterion) { group.bench_function("VpTree", |b| b.iter(|| VpTree::from_iter(points.clone()))); group.bench_function("FlatVpTree", |b| b.iter(|| FlatVpTree::from_iter(points.clone()))); group.bench_function("KdTree", |b| b.iter(|| KdTree::from_iter(points.clone()))); + group.bench_function("FlatKdTree", |b| b.iter(|| FlatKdTree::from_iter(points.clone()))); group.finish(); } @@ -50,12 +51,14 @@ fn bench_nearest_neighbors(c: &mut Criterion) { let vp_tree = VpTree::from_iter(points.clone()); let flat_vp_tree = FlatVpTree::from_iter(points.clone()); let kd_tree = KdTree::from_iter(points.clone()); + let flat_kd_tree = FlatKdTree::from_iter(points.clone()); let mut nearest = c.benchmark_group("NearestNeighbors::nearest"); nearest.bench_function("ExhaustiveSearch", |b| b.iter(|| exhaustive.nearest(&target))); nearest.bench_function("VpTree", |b| b.iter(|| vp_tree.nearest(&target))); nearest.bench_function("FlatVpTree", |b| b.iter(|| flat_vp_tree.nearest(&target))); nearest.bench_function("KdTree", |b| b.iter(|| kd_tree.nearest(&target))); + nearest.bench_function("FlatKdTree", |b| b.iter(|| flat_kd_tree.nearest(&target))); nearest.finish(); let mut nearest_within = c.benchmark_group("NearestNeighbors::nearest_within"); @@ -63,6 +66,7 @@ fn bench_nearest_neighbors(c: &mut Criterion) { nearest_within.bench_function("VpTree", |b| b.iter(|| vp_tree.nearest_within(&target, 0.1))); nearest_within.bench_function("FlatVpTree", |b| b.iter(|| flat_vp_tree.nearest_within(&target, 0.1))); nearest_within.bench_function("KdTree", |b| b.iter(|| kd_tree.nearest_within(&target, 0.1))); + nearest_within.bench_function("FlatKdTree", |b| b.iter(|| flat_kd_tree.nearest_within(&target, 0.1))); nearest_within.finish(); let mut k_nearest = c.benchmark_group("NearestNeighbors::k_nearest"); @@ -70,6 +74,7 @@ fn bench_nearest_neighbors(c: &mut Criterion) { k_nearest.bench_function("VpTree", |b| b.iter(|| vp_tree.k_nearest(&target, 3))); k_nearest.bench_function("FlatVpTree", |b| b.iter(|| flat_vp_tree.k_nearest(&target, 3))); k_nearest.bench_function("KdTree", |b| b.iter(|| kd_tree.k_nearest(&target, 3))); + k_nearest.bench_function("FlatKdTree", |b| b.iter(|| flat_kd_tree.k_nearest(&target, 3))); k_nearest.finish(); let mut k_nearest_within = c.benchmark_group("NearestNeighbors::k_nearest_within"); @@ -77,6 +82,7 @@ fn bench_nearest_neighbors(c: &mut Criterion) { k_nearest_within.bench_function("VpTree", |b| b.iter(|| vp_tree.k_nearest_within(&target, 3, 0.1))); k_nearest_within.bench_function("FlatVpTree", |b| b.iter(|| flat_vp_tree.k_nearest_within(&target, 3, 0.1))); k_nearest_within.bench_function("KdTree", |b| b.iter(|| kd_tree.k_nearest_within(&target, 3, 0.1))); + k_nearest_within.bench_function("FlatKdTree", |b| b.iter(|| flat_kd_tree.k_nearest_within(&target, 3, 0.1))); k_nearest_within.finish(); } @@ -333,6 +333,156 @@ where V: Coordinates, {} +/// A node in a flat k-d tree. +#[derive(Debug)] +struct FlatKdNode<T> { + /// The vantage point itself. + item: T, + /// The size of the left subtree. + left_len: usize, +} + +impl<T: Coordinates> FlatKdNode<T> { + /// Create a new FlatKdNode. + fn new(item: T) -> Self { + Self { + item, + left_len: 0, + } + } + + /// Create a balanced tree. + fn balanced<I: IntoIterator<Item = T>>(items: I) -> Vec<Self> { + let mut nodes: Vec<_> = items + .into_iter() + .map(Self::new) + .collect(); + + Self::balance_recursive(&mut nodes, 0); + + nodes + } + + /// Create a balanced subtree. + fn balance_recursive(nodes: &mut [Self], level: usize) { + if !nodes.is_empty() { + nodes.sort_by_cached_key(|x| Ordered::new(x.item.coord(level))); + + let mid = nodes.len() / 2; + nodes.swap(0, mid); + + let (node, children) = nodes.split_first_mut().unwrap(); + let (left, right) = children.split_at_mut(mid); + node.left_len = left.len(); + + let next = (level + 1) % node.item.dims(); + Self::balance_recursive(left, next); + Self::balance_recursive(right, next); + } + } +} + +impl<'a, K, V, N> KdSearch<K, &'a V, N> for &'a [FlatKdNode<V>] +where + K: KdProximity<&'a V>, + V: Coordinates, + N: Neighborhood<K, &'a V>, +{ + fn item(self) -> &'a V { + &self[0].item + } + + fn left(self) -> Option<Self> { + let end = self[0].left_len + 1; + if end > 1 { + Some(&self[1..end]) + } else { + None + } + } + + fn right(self) -> Option<Self> { + let start = self[0].left_len + 1; + if start < self.len() { + Some(&self[start..]) + } else { + None + } + } +} + +/// A [k-d tree] stored as a flat array. +/// +/// A FlatKdTree is always balanced and usually more efficient than a [KdTree], but doesn't support +/// dynamic updates. +/// +/// [k-d tree]: https://en.wikipedia.org/wiki/K-d_tree +#[derive(Debug)] +pub struct FlatKdTree<T> { + nodes: Vec<FlatKdNode<T>>, +} + +impl<T: Coordinates> FlatKdTree<T> { + /// Create a balanced tree out of a sequence of items. + pub fn balanced<I: IntoIterator<Item = T>>(items: I) -> Self { + Self { + nodes: FlatKdNode::balanced(items), + } + } +} + +impl<T: Coordinates> FromIterator<T> for FlatKdTree<T> { + fn from_iter<I: IntoIterator<Item = T>>(items: I) -> Self { + Self::balanced(items) + } +} + +/// An iterator that moves values out of a flat k-d tree. +#[derive(Debug)] +pub struct FlatIntoIter<T>(std::vec::IntoIter<FlatKdNode<T>>); + +impl<T> Iterator for FlatIntoIter<T> { + type Item = T; + + fn next(&mut self) -> Option<T> { + self.0.next().map(|n| n.item) + } +} + +impl<T> IntoIterator for FlatKdTree<T> { + type Item = T; + type IntoIter = FlatIntoIter<T>; + + fn into_iter(self) -> Self::IntoIter { + FlatIntoIter(self.nodes.into_iter()) + } +} + +impl<K, V> NearestNeighbors<K, V> for FlatKdTree<V> +where + K: KdProximity<V>, + V: Coordinates, +{ + fn search<'k, 'v, N>(&'v self, mut neighborhood: N) -> N + where + K: 'k, + V: 'v, + N: Neighborhood<&'k K, &'v V>, + { + if !self.nodes.is_empty() { + let mut closest = neighborhood.target().as_vec(); + self.nodes.as_slice().search(0, &mut closest, &mut neighborhood); + } + neighborhood + } +} + +impl<K, V> ExactNeighbors<K, V> for FlatKdTree<V> +where + K: KdMetric<V>, + V: Coordinates, +{} + #[cfg(test)] mod tests { use super::*; @@ -354,4 +504,9 @@ mod tests { tree }); } + + #[test] + fn test_flat_kd_tree() { + test_nearest_neighbors(FlatKdTree::from_iter); + } } |