diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/metric.rs | 47 | ||||
-rw-r--r-- | src/metric/kd.rs | 6 | ||||
-rw-r--r-- | src/metric/vp.rs | 4 |
3 files changed, 24 insertions, 33 deletions
diff --git a/src/metric.rs b/src/metric.rs index 268aefd..ff996b9 100644 --- a/src/metric.rs +++ b/src/metric.rs @@ -6,12 +6,22 @@ pub mod kd; pub mod soft; pub mod vp; -use ordered_float::OrderedFloat; - use std::cmp::Ordering; use std::collections::BinaryHeap; use std::iter::FromIterator; +/// A wrapper that converts a partial ordering into a total one by panicking. +#[derive(Debug, PartialEq, PartialOrd)] +struct Ordered<T>(T); + +impl<T: PartialOrd> Ord for Ordered<T> { + fn cmp(&self, other: &Self) -> Ordering { + self.partial_cmp(other).unwrap() + } +} + +impl<T: PartialEq> Eq for Ordered<T> {} + /// An [order embedding](https://en.wikipedia.org/wiki/Order_embedding) for distances. /// /// Implementations of this trait must satisfy, for all non-negative distances `x` and `y`: @@ -22,34 +32,18 @@ use std::iter::FromIterator; /// This trait exists to optimize the common case where distances can be compared more efficiently /// than their exact values can be computed. For example, taking the square root can be avoided /// when comparing Euclidean distances (see [SquaredDistance]). -pub trait Distance: Copy + From<f64> + Into<f64> + Ord {} - -/// A raw numerical distance. -#[derive(Debug, Clone, Copy, Eq, Ord, PartialEq, PartialOrd)] -pub struct RawDistance(OrderedFloat<f64>); - -impl From<f64> for RawDistance { - fn from(value: f64) -> Self { - Self(value.into()) - } -} - -impl From<RawDistance> for f64 { - fn from(value: RawDistance) -> Self { - value.0.into_inner() - } -} +pub trait Distance: Copy + From<f64> + Into<f64> + PartialOrd {} -impl Distance for RawDistance {} +impl Distance for f64 {} /// A squared distance, to avoid computing square roots unless absolutely necessary. -#[derive(Debug, Clone, Copy, Eq, Ord, PartialEq, PartialOrd)] -pub struct SquaredDistance(OrderedFloat<f64>); +#[derive(Debug, Clone, Copy, PartialEq, PartialOrd)] +pub struct SquaredDistance(f64); impl SquaredDistance { /// Create a SquaredDistance from an already squared value. pub fn from_squared(value: f64) -> Self { - Self(value.into()) + Self(value) } } @@ -61,7 +55,7 @@ impl From<f64> for SquaredDistance { impl From<SquaredDistance> for f64 { fn from(value: SquaredDistance) -> Self { - value.0.into_inner().sqrt() + value.0.sqrt() } } @@ -69,8 +63,7 @@ impl Distance for SquaredDistance {} /// A [metric space](https://en.wikipedia.org/wiki/Metric_space). pub trait Metric<T: ?Sized = Self> { - /// The type used to represent distances. Use [RawDistance] to compare the actual values - /// directly, or another type if comparisons can be implemented more efficiently. + /// The type used to represent distances. type Distance: Distance; /// Computes the distance between this point and another point. This function must satisfy @@ -153,7 +146,7 @@ impl<T, D: Distance> PartialOrd for Candidate<T, D> { impl<T, D: Distance> Ord for Candidate<T, D> { fn cmp(&self, other: &Self) -> Ordering { - self.distance.cmp(&other.distance) + self.partial_cmp(other).unwrap() } } diff --git a/src/metric/kd.rs b/src/metric/kd.rs index 2caf4a3..6ea3809 100644 --- a/src/metric/kd.rs +++ b/src/metric/kd.rs @@ -1,8 +1,6 @@ //! [k-d trees](https://en.wikipedia.org/wiki/K-d_tree). -use super::{Metric, NearestNeighbors, Neighborhood}; - -use ordered_float::OrderedFloat; +use super::{Metric, NearestNeighbors, Neighborhood, Ordered}; use std::iter::FromIterator; @@ -82,7 +80,7 @@ impl<T: Cartesian> KdNode<T> { return; } - slice.sort_unstable_by_key(|n| OrderedFloat::from(n.item.coordinate(i))); + slice.sort_unstable_by_key(|n| Ordered(n.item.coordinate(i))); let mid = slice.len() / 2; slice.swap(0, mid); diff --git a/src/metric/vp.rs b/src/metric/vp.rs index fae62e5..d6e05df 100644 --- a/src/metric/vp.rs +++ b/src/metric/vp.rs @@ -1,6 +1,6 @@ //! [Vantage-point trees](https://en.wikipedia.org/wiki/Vantage-point_tree). -use super::{Metric, NearestNeighbors, Neighborhood}; +use super::{Metric, NearestNeighbors, Neighborhood, Ordered}; use std::iter::FromIterator; @@ -29,7 +29,7 @@ impl<T: Metric> VpNode<T> { fn build(slice: &mut [VpNode<T>]) { if let Some((node, children)) = slice.split_first_mut() { let item = &node.item; - children.sort_by_cached_key(|n| item.distance(&n.item)); + children.sort_by_cached_key(|n| Ordered(item.distance(&n.item))); let (inside, outside) = children.split_at_mut(children.len() / 2); if let Some(last) = inside.last() { |