summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorTavian Barnes <tavianator@tavianator.com>2020-04-23 14:58:47 -0400
committerTavian Barnes <tavianator@tavianator.com>2020-05-03 00:16:33 -0400
commitb4a39a3f22fac361f6a535d281eee5586078281b (patch)
tree0f2c0f6c584d391d428075351a60c9e8e23908e7
parenta4a75059f302de2a00971f1f485fcf4389710628 (diff)
downloadkd-forest-b4a39a3f22fac361f6a535d281eee5586078281b.tar.xz
metric/forest: Optimize bulk insertion
-rw-r--r--src/metric/forest.rs47
1 files changed, 27 insertions, 20 deletions
diff --git a/src/metric/forest.rs b/src/metric/forest.rs
index f23c451..29b6f55 100644
--- a/src/metric/forest.rs
+++ b/src/metric/forest.rs
@@ -4,7 +4,7 @@ use super::kd::KdTree;
use super::vp::VpTree;
use super::{Metric, NearestNeighbors, Neighborhood};
-use std::iter::{Extend, Flatten, FromIterator};
+use std::iter::{self, Extend, Flatten, FromIterator};
/// A dynamic wrapper for a static nearest neighbor search data structure.
///
@@ -24,23 +24,7 @@ where
/// Add a new item to the forest.
pub fn push(&mut self, item: T) {
- let mut items = vec![item];
-
- for slot in &mut self.0 {
- match slot.take() {
- // Collect the items from any trees we encounter...
- Some(tree) => {
- items.extend(tree);
- }
- // ... and put them all in the first empty slot
- None => {
- *slot = Some(items.into_iter().collect());
- return;
- }
- }
- }
-
- self.0.push(Some(items.into_iter().collect()));
+ self.extend(iter::once(item));
}
/// Get the number of items in the forest.
@@ -60,9 +44,32 @@ where
U: FromIterator<T> + IntoIterator<Item = T>,
{
fn extend<I: IntoIterator<Item = T>>(&mut self, items: I) {
- for item in items {
- self.push(item);
+ let mut vec: Vec<_> = items.into_iter().collect();
+ let new_len = self.len() + vec.len();
+
+ for i in 0.. {
+ let bit = 1 << i;
+
+ if bit > new_len {
+ break;
+ }
+
+ if i >= self.0.len() {
+ self.0.push(None);
+ }
+
+ if new_len & bit == 0 {
+ if let Some(tree) = self.0[i].take() {
+ vec.extend(tree);
+ }
+ } else if self.0[i].is_none() {
+ let offset = vec.len() - bit;
+ self.0[i] = Some(vec.drain(offset..).collect());
+ }
}
+
+ debug_assert!(vec.is_empty());
+ debug_assert!(self.len() == new_len);
}
}