diff options
-rw-r--r-- | src/metric/vp.rs | 133 |
1 files changed, 51 insertions, 82 deletions
diff --git a/src/metric/vp.rs b/src/metric/vp.rs index 8d5b091..fae62e5 100644 --- a/src/metric/vp.rs +++ b/src/metric/vp.rs @@ -11,78 +11,62 @@ struct VpNode<T> { item: T, /// The radius of this node. radius: f64, - /// The subtree inside the radius, if any. - inside: Option<Box<Self>>, - /// The subtree outside the radius, if any. - outside: Option<Box<Self>>, + /// The size of the subtree inside the radius. + inside_len: usize, } impl<T: Metric> VpNode<T> { /// Create a new VpNode. - fn new(mut items: Vec<T>) -> Option<Box<Self>> { - if items.is_empty() { - return None; + fn new(item: T) -> Self { + Self { + item, + radius: 0.0, + inside_len: 0, } + } - let item = items.pop().unwrap(); - - items.sort_by_cached_key(|a| item.distance(a)); - - let mid = items.len() / 2; - let outside: Vec<T> = items.drain(mid..).collect(); + /// Build a VP tree recursively. + fn build(slice: &mut [VpNode<T>]) { + if let Some((node, children)) = slice.split_first_mut() { + let item = &node.item; + children.sort_by_cached_key(|n| item.distance(&n.item)); - let radius = items.last().map(|l| item.distance(l).into()).unwrap_or(0.0); + let (inside, outside) = children.split_at_mut(children.len() / 2); + if let Some(last) = inside.last() { + node.radius = item.distance(&last.item).into(); + } + node.inside_len = inside.len(); - Some(Box::new(Self { - item, - radius, - inside: Self::new(items), - outside: Self::new(outside), - })) + Self::build(inside); + Self::build(outside); + } } -} -trait VpSearch<'a, T, U, N> { /// Recursively search for nearest neighbors. - fn search(&'a self, neighborhood: &mut N); - - /// Search the inside subtree. - fn search_inside(&'a self, distance: f64, neighborhood: &mut N); - - /// Search the outside subtree. - fn search_outside(&'a self, distance: f64, neighborhood: &mut N); -} + fn recurse<'a, U, N>(slice: &'a [VpNode<T>], neighborhood: &mut N) + where + T: 'a, + U: Metric<&'a T>, + N: Neighborhood<&'a T, U>, + { + let (node, children) = slice.split_first().unwrap(); + let (inside, outside) = children.split_at(node.inside_len); -impl<'a, T, U, N> VpSearch<'a, T, U, N> for VpNode<T> -where - T: 'a, - U: Metric<&'a T>, - N: Neighborhood<&'a T, U>, -{ - fn search(&'a self, neighborhood: &mut N) { - let distance = neighborhood.consider(&self.item).into(); + let distance = neighborhood.consider(&node.item).into(); - if distance <= self.radius { - self.search_inside(distance, neighborhood); - self.search_outside(distance, neighborhood); + if distance <= node.radius { + if !inside.is_empty() && neighborhood.contains(distance - node.radius) { + Self::recurse(inside, neighborhood); + } + if !outside.is_empty() && neighborhood.contains(node.radius - distance) { + Self::recurse(outside, neighborhood); + } } else { - self.search_outside(distance, neighborhood); - self.search_inside(distance, neighborhood); - } - } - - fn search_inside(&'a self, distance: f64, neighborhood: &mut N) { - if let Some(inside) = &self.inside { - if neighborhood.contains(distance - self.radius) { - inside.search(neighborhood); + if !outside.is_empty() && neighborhood.contains(node.radius - distance) { + Self::recurse(outside, neighborhood); } - } - } - - fn search_outside(&'a self, distance: f64, neighborhood: &mut N) { - if let Some(outside) = &self.outside { - if neighborhood.contains(self.radius - distance) { - outside.search(neighborhood); + if !inside.is_empty() && neighborhood.contains(distance - node.radius) { + Self::recurse(inside, neighborhood); } } } @@ -90,15 +74,13 @@ where /// A [vantage-point tree](https://en.wikipedia.org/wiki/Vantage-point_tree). #[derive(Debug)] -pub struct VpTree<T> { - root: Option<Box<VpNode<T>>>, -} +pub struct VpTree<T>(Vec<VpNode<T>>); impl<T: Metric> FromIterator<T> for VpTree<T> { fn from_iter<I: IntoIterator<Item = T>>(items: I) -> Self { - Self { - root: VpNode::new(items.into_iter().collect::<Vec<_>>()), - } + let mut nodes: Vec<_> = items.into_iter().map(VpNode::new).collect(); + VpNode::build(nodes.as_mut_slice()); + Self(nodes) } } @@ -113,36 +95,23 @@ where U: 'b, N: Neighborhood<&'a T, &'b U>, { - if let Some(root) = &self.root { - root.search(&mut neighborhood); + if !self.0.is_empty() { + VpNode::recurse(&self.0, &mut neighborhood); } + neighborhood } } /// An iterator that moves values out of a VP tree. #[derive(Debug)] -pub struct IntoIter<T> { - stack: Vec<Box<VpNode<T>>>, -} - -impl<T> IntoIter<T> { - fn new(node: Option<Box<VpNode<T>>>) -> Self { - Self { - stack: node.into_iter().collect(), - } - } -} +pub struct IntoIter<T>(std::vec::IntoIter<VpNode<T>>); impl<T> Iterator for IntoIter<T> { type Item = T; fn next(&mut self) -> Option<T> { - self.stack.pop().map(|node| { - self.stack.extend(node.inside); - self.stack.extend(node.outside); - node.item - }) + self.0.next().map(|n| n.item) } } @@ -151,7 +120,7 @@ impl<T> IntoIterator for VpTree<T> { type IntoIter = IntoIter<T>; fn into_iter(self) -> Self::IntoIter { - IntoIter::new(self.root) + IntoIter(self.0.into_iter()) } } |