diff options
Diffstat (limited to 'src/lib.rs')
-rw-r--r-- | src/lib.rs | 86 |
1 files changed, 40 insertions, 46 deletions
@@ -248,6 +248,9 @@ impl<'a, K, V, D: Distance> HeapNeighborhood<'a, K, V, D> { mut threshold: Option<D>, heap: &'a mut Vec<Neighbor<V, D>>, ) -> Self { + // A descending array is also a max-heap + heap.reverse(); + if k > 0 && heap.len() == k { let distance = heap[0].distance; if threshold.map_or(true, |t| distance <= t) { @@ -263,20 +266,25 @@ impl<'a, K, V, D: Distance> HeapNeighborhood<'a, K, V, D> { } } - /// Restore the heap property by raising an entry. - fn bubble_up(&mut self, mut i: usize) { + /// Push a new element into the heap. + fn push(&mut self, item: Neighbor<V, D>) { + let mut i = self.heap.len(); + self.heap.push(item); + while i > 0 { let parent = (i - 1) / 2; - if self.heap[i].distance <= self.heap[parent].distance { + if self.heap[i].distance > self.heap[parent].distance { + self.heap.swap(i, parent); + i = parent; + } else { break; } - self.heap.swap(i, parent); - i = parent; } } - /// Restore the heap property by lowering an entry. - fn bubble_down(&mut self, mut i: usize, len: usize) { + /// Restore the heap property by lowering the root. + fn sink_root(&mut self, len: usize) { + let mut i = 0; let dist = self.heap[i].distance; loop { @@ -295,11 +303,17 @@ impl<'a, K, V, D: Distance> HeapNeighborhood<'a, K, V, D> { } } + /// Replace the root of the heap with a new element. + fn replace_root(&mut self, item: Neighbor<V, D>) { + self.heap[0] = item; + self.sink_root(self.heap.len()); + } + /// Sort the heap from smallest to largest distance. fn sort(&mut self) { for i in (0..self.heap.len()).rev() { self.heap.swap(0, i); - self.bubble_down(0, i); + self.sink_root(i); } } } @@ -326,11 +340,9 @@ where let neighbor = Neighbor::new(item, distance); if self.heap.len() < self.k { - self.heap.push(neighbor); - self.bubble_up(self.heap.len() - 1); + self.push(neighbor); } else { - self.heap[0] = neighbor; - self.bubble_down(0, self.heap.len()); + self.replace_root(neighbor); } if self.heap.len() == self.k { @@ -381,8 +393,7 @@ pub trait NearestNeighbors<K: Proximity<V>, V = K> { /// The result will be sorted from nearest to farthest. fn k_nearest(&self, target: &K, k: usize) -> Vec<Neighbor<&V, K::Distance>> { let mut neighbors = Vec::with_capacity(k); - self.search(HeapNeighborhood::new(target, k, None, &mut neighbors)) - .sort(); + self.merge_k_nearest(target, k, &mut neighbors); neighbors } @@ -399,52 +410,34 @@ pub trait NearestNeighbors<K: Proximity<V>, V = K> { D: TryInto<K::Distance>, { let mut neighbors = Vec::with_capacity(k); - - if let Ok(distance) = threshold.try_into() { - self.search(HeapNeighborhood::new( - target, - k, - Some(distance), - &mut neighbors, - )) - .sort(); - } - + self.merge_k_nearest_within(target, k, threshold, &mut neighbors); neighbors } - /// Merges up to `k` nearest neighbors into an existing vector. - /// - /// The `neigbors` vector should either be empty, or populated by a previous call to - /// `merge_k_nearest()`. This method assumes a particular ordering that makes merging new - /// results efficient. If you want the results ordered from nearest to farthest, you must sort - /// it yourself. + /// Merges up to `k` nearest neighbors into an existing sorted vector. fn merge_k_nearest<'v>( &'v self, target: &K, k: usize, neighbors: &mut Vec<Neighbor<&'v V, K::Distance>>, ) { - self.search(HeapNeighborhood::new(target, k, None, neighbors)); + self.search(HeapNeighborhood::new(target, k, None, neighbors)) + .sort(); } - /// Merges up to `k` nearest neighbors within the `threshold` into an existing vector. - /// - /// The `neigbors` vector should either be empty, or populated by a previous call to - /// `merge_k_nearest()`. This method assumes a particular ordering that makes merging new - /// results efficient. If you want the results ordered from nearest to farthest, you must sort - /// it yourself. + /// Merges up to `k` nearest neighbors within the `threshold` into an existing sorted vector. fn merge_k_nearest_within<'v, D>( &'v self, target: &K, k: usize, - neighbors: &mut Vec<Neighbor<&'v V, K::Distance>>, threshold: D, + neighbors: &mut Vec<Neighbor<&'v V, K::Distance>>, ) where D: TryInto<K::Distance>, { if let Ok(distance) = threshold.try_into() { - self.search(HeapNeighborhood::new(target, k, Some(distance), neighbors)); + self.search(HeapNeighborhood::new(target, k, Some(distance), neighbors)) + .sort(); } } @@ -464,7 +457,6 @@ pub mod tests { use super::*; use crate::exhaustive::ExhaustiveSearch; - use crate::util::Ordered; use rand::prelude::*; @@ -555,7 +547,6 @@ pub mod tests { let mut neighbors = Vec::new(); index.merge_k_nearest(&target, 3, &mut neighbors); - neighbors.sort_by_key(|n| Ordered::new(n.distance)); assert_eq!( neighbors, vec![ @@ -565,15 +556,18 @@ pub mod tests { ] ); - neighbors.drain(0..2); - index.merge_k_nearest_within(&target, 3, &mut neighbors, 6.0); - neighbors.sort_by_key(|n| Ordered::new(n.distance)); + neighbors = vec![ + Neighbor::new(&target, EuclideanDistance::from_squared(0.0)), + Neighbor::new(&Euclidean([3.0, 4.0, 0.0]), EuclideanDistance::from_squared(25.0)), + Neighbor::new(&Euclidean([2.0, 3.0, 6.0]), EuclideanDistance::from_squared(49.0)), + ]; + index.merge_k_nearest_within(&target, 3, 4.0, &mut neighbors); assert_eq!( neighbors, vec![ + Neighbor::new(&target, 0.0), Neighbor::new(&Euclidean([1.0, 2.0, 2.0]), 3.0), Neighbor::new(&Euclidean([3.0, 4.0, 0.0]), 5.0), - Neighbor::new(&Euclidean([2.0, 3.0, 6.0]), 7.0), ] ); } |