From e9a81a6d0df149252164003975addf175d5c6f4b Mon Sep 17 00:00:00 2001 From: Tavian Barnes Date: Thu, 23 Apr 2020 09:55:13 -0400 Subject: metric/kd: Flatten the tree representation --- src/metric/kd.rs | 113 ++++++++++++++++++++++++++----------------------------- 1 file changed, 54 insertions(+), 59 deletions(-) diff --git a/src/metric/kd.rs b/src/metric/kd.rs index db1b2bd..2caf4a3 100644 --- a/src/metric/kd.rs +++ b/src/metric/kd.rs @@ -66,61 +66,71 @@ where struct KdNode { /// The value stored in this node. item: T, - /// The left subtree, if any. - left: Option>, - /// The right subtree, if any. - right: Option>, + /// The size of the left subtree. + left_len: usize, } impl KdNode { /// Create a new KdNode. - fn new(i: usize, mut items: Vec) -> Option> { - if items.is_empty() { - return None; + fn new(item: T) -> Self { + Self { item, left_len: 0 } + } + + /// Build a k-d tree recursively. + fn build(slice: &mut [KdNode], i: usize) { + if slice.is_empty() { + return; } - items.sort_unstable_by_key(|x| OrderedFloat::from(x.coordinate(i))); - - let mid = items.len() / 2; - let right: Vec = items.drain((mid + 1)..).collect(); - let item = items.pop().unwrap(); - let j = (i + 1) % item.dimensions(); - Some(Box::new(Self { - item, - left: Self::new(j, items), - right: Self::new(j, right), - })) + slice.sort_unstable_by_key(|n| OrderedFloat::from(n.item.coordinate(i))); + + let mid = slice.len() / 2; + slice.swap(0, mid); + + let (node, children) = slice.split_first_mut().unwrap(); + let (left, right) = children.split_at_mut(mid); + node.left_len = left.len(); + + let j = (i + 1) % node.item.dimensions(); + Self::build(left, j); + Self::build(right, j); } /// Recursively search for nearest neighbors. - fn search<'a, U, N>(&'a self, i: usize, closest: &mut [f64], neighborhood: &mut N) - where + fn recurse<'a, U, N>( + slice: &'a [KdNode], + i: usize, + closest: &mut [f64], + neighborhood: &mut N, + ) where T: 'a, U: CartesianMetric<&'a T>, N: Neighborhood<&'a T, U>, { - neighborhood.consider(&self.item); + let (node, children) = slice.split_first().unwrap(); + neighborhood.consider(&node.item); let target = neighborhood.target(); let ti = target.coordinate(i); - let si = self.item.coordinate(i); - let j = (i + 1) % self.item.dimensions(); + let ni = node.item.coordinate(i); + let j = (i + 1) % node.item.dimensions(); - let (near, far) = if ti <= si { - (&self.left, &self.right) + let (left, right) = children.split_at(node.left_len); + let (near, far) = if ti <= ni { + (left, right) } else { - (&self.right, &self.left) + (right, left) }; - if let Some(near) = near { - near.search(j, closest, neighborhood); + if !near.is_empty() { + Self::recurse(near, j, closest, neighborhood); } - if let Some(far) = far { + if !far.is_empty() { let saved = closest[i]; - closest[i] = si; + closest[i] = ni; if neighborhood.contains_distance(target.distance(closest)) { - far.search(j, closest, neighborhood); + Self::recurse(far, j, closest, neighborhood); } closest[i] = saved; } @@ -129,16 +139,14 @@ impl KdNode { /// A [k-d tree](https://en.wikipedia.org/wiki/K-d_tree). #[derive(Debug)] -pub struct KdTree { - root: Option>>, -} +pub struct KdTree(Vec>); impl FromIterator for KdTree { /// Create a new k-d tree from a set of points. fn from_iter>(items: I) -> Self { - Self { - root: KdNode::new(0, items.into_iter().collect()), - } + let mut nodes: Vec<_> = items.into_iter().map(KdNode::new).collect(); + KdNode::build(nodes.as_mut_slice(), 0); + Self(nodes) } } @@ -153,40 +161,27 @@ where U: 'b, N: Neighborhood<&'a T, &'b U>, { - let target = neighborhood.target(); - let dims = target.dimensions(); - let mut closest: Vec<_> = (0..dims).map(|i| target.coordinate(i)).collect(); + if !self.0.is_empty() { + let target = neighborhood.target(); + let dims = target.dimensions(); + let mut closest: Vec<_> = (0..dims).map(|i| target.coordinate(i)).collect(); - if let Some(root) = &self.root { - root.search(0, &mut closest, &mut neighborhood); + KdNode::recurse(&self.0, 0, &mut closest, &mut neighborhood); } + neighborhood } } /// An iterator that the moves values out of a k-d tree. #[derive(Debug)] -pub struct IntoIter { - stack: Vec>>, -} - -impl IntoIter { - fn new(node: Option>>) -> Self { - Self { - stack: node.into_iter().collect(), - } - } -} +pub struct IntoIter(std::vec::IntoIter>); impl Iterator for IntoIter { type Item = T; fn next(&mut self) -> Option { - self.stack.pop().map(|node| { - self.stack.extend(node.left); - self.stack.extend(node.right); - node.item - }) + self.0.next().map(|n| n.item) } } @@ -195,7 +190,7 @@ impl IntoIterator for KdTree { type IntoIter = IntoIter; fn into_iter(self) -> Self::IntoIter { - IntoIter::new(self.root) + IntoIter(self.0.into_iter()) } } -- cgit v1.2.3