summaryrefslogtreecommitdiffstats
path: root/src/lib.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/lib.rs')
-rw-r--r--src/lib.rs86
1 files changed, 40 insertions, 46 deletions
diff --git a/src/lib.rs b/src/lib.rs
index 986c1d3..d6e5579 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -248,6 +248,9 @@ impl<'a, K, V, D: Distance> HeapNeighborhood<'a, K, V, D> {
mut threshold: Option<D>,
heap: &'a mut Vec<Neighbor<V, D>>,
) -> Self {
+ // A descending array is also a max-heap
+ heap.reverse();
+
if k > 0 && heap.len() == k {
let distance = heap[0].distance;
if threshold.map_or(true, |t| distance <= t) {
@@ -263,20 +266,25 @@ impl<'a, K, V, D: Distance> HeapNeighborhood<'a, K, V, D> {
}
}
- /// Restore the heap property by raising an entry.
- fn bubble_up(&mut self, mut i: usize) {
+ /// Push a new element into the heap.
+ fn push(&mut self, item: Neighbor<V, D>) {
+ let mut i = self.heap.len();
+ self.heap.push(item);
+
while i > 0 {
let parent = (i - 1) / 2;
- if self.heap[i].distance <= self.heap[parent].distance {
+ if self.heap[i].distance > self.heap[parent].distance {
+ self.heap.swap(i, parent);
+ i = parent;
+ } else {
break;
}
- self.heap.swap(i, parent);
- i = parent;
}
}
- /// Restore the heap property by lowering an entry.
- fn bubble_down(&mut self, mut i: usize, len: usize) {
+ /// Restore the heap property by lowering the root.
+ fn sink_root(&mut self, len: usize) {
+ let mut i = 0;
let dist = self.heap[i].distance;
loop {
@@ -295,11 +303,17 @@ impl<'a, K, V, D: Distance> HeapNeighborhood<'a, K, V, D> {
}
}
+ /// Replace the root of the heap with a new element.
+ fn replace_root(&mut self, item: Neighbor<V, D>) {
+ self.heap[0] = item;
+ self.sink_root(self.heap.len());
+ }
+
/// Sort the heap from smallest to largest distance.
fn sort(&mut self) {
for i in (0..self.heap.len()).rev() {
self.heap.swap(0, i);
- self.bubble_down(0, i);
+ self.sink_root(i);
}
}
}
@@ -326,11 +340,9 @@ where
let neighbor = Neighbor::new(item, distance);
if self.heap.len() < self.k {
- self.heap.push(neighbor);
- self.bubble_up(self.heap.len() - 1);
+ self.push(neighbor);
} else {
- self.heap[0] = neighbor;
- self.bubble_down(0, self.heap.len());
+ self.replace_root(neighbor);
}
if self.heap.len() == self.k {
@@ -381,8 +393,7 @@ pub trait NearestNeighbors<K: Proximity<V>, V = K> {
/// The result will be sorted from nearest to farthest.
fn k_nearest(&self, target: &K, k: usize) -> Vec<Neighbor<&V, K::Distance>> {
let mut neighbors = Vec::with_capacity(k);
- self.search(HeapNeighborhood::new(target, k, None, &mut neighbors))
- .sort();
+ self.merge_k_nearest(target, k, &mut neighbors);
neighbors
}
@@ -399,52 +410,34 @@ pub trait NearestNeighbors<K: Proximity<V>, V = K> {
D: TryInto<K::Distance>,
{
let mut neighbors = Vec::with_capacity(k);
-
- if let Ok(distance) = threshold.try_into() {
- self.search(HeapNeighborhood::new(
- target,
- k,
- Some(distance),
- &mut neighbors,
- ))
- .sort();
- }
-
+ self.merge_k_nearest_within(target, k, threshold, &mut neighbors);
neighbors
}
- /// Merges up to `k` nearest neighbors into an existing vector.
- ///
- /// The `neigbors` vector should either be empty, or populated by a previous call to
- /// `merge_k_nearest()`. This method assumes a particular ordering that makes merging new
- /// results efficient. If you want the results ordered from nearest to farthest, you must sort
- /// it yourself.
+ /// Merges up to `k` nearest neighbors into an existing sorted vector.
fn merge_k_nearest<'v>(
&'v self,
target: &K,
k: usize,
neighbors: &mut Vec<Neighbor<&'v V, K::Distance>>,
) {
- self.search(HeapNeighborhood::new(target, k, None, neighbors));
+ self.search(HeapNeighborhood::new(target, k, None, neighbors))
+ .sort();
}
- /// Merges up to `k` nearest neighbors within the `threshold` into an existing vector.
- ///
- /// The `neigbors` vector should either be empty, or populated by a previous call to
- /// `merge_k_nearest()`. This method assumes a particular ordering that makes merging new
- /// results efficient. If you want the results ordered from nearest to farthest, you must sort
- /// it yourself.
+ /// Merges up to `k` nearest neighbors within the `threshold` into an existing sorted vector.
fn merge_k_nearest_within<'v, D>(
&'v self,
target: &K,
k: usize,
- neighbors: &mut Vec<Neighbor<&'v V, K::Distance>>,
threshold: D,
+ neighbors: &mut Vec<Neighbor<&'v V, K::Distance>>,
) where
D: TryInto<K::Distance>,
{
if let Ok(distance) = threshold.try_into() {
- self.search(HeapNeighborhood::new(target, k, Some(distance), neighbors));
+ self.search(HeapNeighborhood::new(target, k, Some(distance), neighbors))
+ .sort();
}
}
@@ -464,7 +457,6 @@ pub mod tests {
use super::*;
use crate::exhaustive::ExhaustiveSearch;
- use crate::util::Ordered;
use rand::prelude::*;
@@ -555,7 +547,6 @@ pub mod tests {
let mut neighbors = Vec::new();
index.merge_k_nearest(&target, 3, &mut neighbors);
- neighbors.sort_by_key(|n| Ordered::new(n.distance));
assert_eq!(
neighbors,
vec![
@@ -565,15 +556,18 @@ pub mod tests {
]
);
- neighbors.drain(0..2);
- index.merge_k_nearest_within(&target, 3, &mut neighbors, 6.0);
- neighbors.sort_by_key(|n| Ordered::new(n.distance));
+ neighbors = vec![
+ Neighbor::new(&target, EuclideanDistance::from_squared(0.0)),
+ Neighbor::new(&Euclidean([3.0, 4.0, 0.0]), EuclideanDistance::from_squared(25.0)),
+ Neighbor::new(&Euclidean([2.0, 3.0, 6.0]), EuclideanDistance::from_squared(49.0)),
+ ];
+ index.merge_k_nearest_within(&target, 3, 4.0, &mut neighbors);
assert_eq!(
neighbors,
vec![
+ Neighbor::new(&target, 0.0),
Neighbor::new(&Euclidean([1.0, 2.0, 2.0]), 3.0),
Neighbor::new(&Euclidean([3.0, 4.0, 0.0]), 5.0),
- Neighbor::new(&Euclidean([2.0, 3.0, 6.0]), 7.0),
]
);
}