diff options
Diffstat (limited to 'src/metric')
-rw-r--r-- | src/metric/forest.rs | 86 |
1 files changed, 48 insertions, 38 deletions
diff --git a/src/metric/forest.rs b/src/metric/forest.rs index 29b6f55..47eb413 100644 --- a/src/metric/forest.rs +++ b/src/metric/forest.rs @@ -4,14 +4,24 @@ use super::kd::KdTree; use super::vp::VpTree; use super::{Metric, NearestNeighbors, Neighborhood}; -use std::iter::{self, Extend, Flatten, FromIterator}; +use std::iter::{self, Extend, FromIterator}; + +/// The number of bits dedicated to the flat buffer. +const BUFFER_BITS: usize = 6; +/// The maximum size of the buffer. +const BUFFER_SIZE: usize = 1 << BUFFER_BITS; /// A dynamic wrapper for a static nearest neighbor search data structure. /// /// This type applies [dynamization](https://en.wikipedia.org/wiki/Dynamization) to an arbitrary /// nearest neighbor search structure `T`, allowing new items to be added dynamically. #[derive(Debug)] -pub struct Forest<T>(Vec<Option<T>>); +pub struct Forest<T: IntoIterator> { + /// A flat buffer used for the first few items, to avoid repeatedly rebuilding small trees. + buffer: Vec<T::Item>, + /// The trees of the forest, with sizes in geometric progression. + trees: Vec<Option<T>>, +} impl<T, U> Forest<U> where @@ -19,7 +29,10 @@ where { /// Create a new empty forest. pub fn new() -> Self { - Self(Vec::new()) + Self { + buffer: Vec::new(), + trees: Vec::new(), + } } /// Add a new item to the forest. @@ -29,10 +42,10 @@ where /// Get the number of items in the forest. pub fn len(&self) -> usize { - let mut len = 0; - for (i, slot) in self.0.iter().enumerate() { + let mut len = self.buffer.len(); + for (i, slot) in self.trees.iter().enumerate() { if slot.is_some() { - len |= 1 << i; + len += 1 << (i + BUFFER_BITS); } } len @@ -44,32 +57,36 @@ where U: FromIterator<T> + IntoIterator<Item = T>, { fn extend<I: IntoIterator<Item = T>>(&mut self, items: I) { - let mut vec: Vec<_> = items.into_iter().collect(); - let new_len = self.len() + vec.len(); + self.buffer.extend(items); + if self.buffer.len() < BUFFER_SIZE { + return; + } + + let len = self.len(); for i in 0.. { - let bit = 1 << i; + let bit = 1 << (i + BUFFER_BITS); - if bit > new_len { + if bit > len { break; } - if i >= self.0.len() { - self.0.push(None); + if i >= self.trees.len() { + self.trees.push(None); } - if new_len & bit == 0 { - if let Some(tree) = self.0[i].take() { - vec.extend(tree); + if len & bit == 0 { + if let Some(tree) = self.trees[i].take() { + self.buffer.extend(tree); } - } else if self.0[i].is_none() { - let offset = vec.len() - bit; - self.0[i] = Some(vec.drain(offset..).collect()); + } else if self.trees[i].is_none() { + let offset = self.buffer.len() - bit; + self.trees[i] = Some(self.buffer.drain(offset..).collect()); } } - debug_assert!(vec.is_empty()); - debug_assert!(self.len() == new_len); + debug_assert!(self.buffer.len() < BUFFER_SIZE); + debug_assert!(self.len() == len); } } @@ -84,25 +101,13 @@ where } } -type IntoIterImpl<T> = Flatten<Flatten<std::vec::IntoIter<Option<T>>>>; - -/// An iterator that moves items out of a forest. -pub struct IntoIter<T: IntoIterator>(IntoIterImpl<T>); - -impl<T: IntoIterator> Iterator for IntoIter<T> { - type Item = T::Item; - - fn next(&mut self) -> Option<Self::Item> { - self.0.next() - } -} - impl<T: IntoIterator> IntoIterator for Forest<T> { type Item = T::Item; - type IntoIter = IntoIter<T>; + type IntoIter = std::vec::IntoIter<T::Item>; - fn into_iter(self) -> Self::IntoIter { - IntoIter(self.0.into_iter().flatten().flatten()) + fn into_iter(mut self) -> Self::IntoIter { + self.buffer.extend(self.trees.into_iter().flatten().flatten()); + self.buffer.into_iter() } } @@ -110,14 +115,19 @@ impl<T, U, V> NearestNeighbors<T, U> for Forest<V> where U: Metric<T>, V: NearestNeighbors<T, U>, + V: IntoIterator<Item = T>, { - fn search<'a, 'b, N>(&'a self, neighborhood: N) -> N + fn search<'a, 'b, N>(&'a self, mut neighborhood: N) -> N where T: 'a, U: 'b, N: Neighborhood<&'a T, &'b U>, { - self.0 + for item in &self.buffer { + neighborhood.consider(item); + } + + self.trees .iter() .flatten() .fold(neighborhood, |n, t| t.search(n)) |