diff options
-rw-r--r-- | src/cos.rs | 348 |
1 files changed, 332 insertions, 16 deletions
@@ -1,17 +1,19 @@ //! [Cosine distance](https://en.wikipedia.org/wiki/Cosine_similarity). use crate::coords::Coordinates; -use crate::distance::{Metric, Proximity}; +use crate::distance::{Distance, Metric, Proximity, Value}; use num_traits::real::Real; use num_traits::{one, zero}; +use std::cmp::Ordering; + /// Compute the [cosine *similarity*] between two points. /// -/// This is not suitable for implementing [`Proximity::distance()`] because the result is reversed +/// Use [cosine_distance] instead if you are implementing [Proximity::distance()]. /// /// [cosine *similarity*]: https://en.wikipedia.org/wiki/Cosine_similarity -/// [`Proximity::distance()`]: Proximity#method.distance +/// [Proximity::distance()]: Proximity#method.distance pub fn cosine_similarity<T, U>(x: T, y: U) -> T::Value where T: Coordinates, @@ -91,16 +93,96 @@ where } } +/// Compute the [cosine *similarity*] between two pre-normalized (unit magnitude) points. +/// +/// Use [prenorm_cosine_distance] instead if you are implementing [Proximity::distance()]. +/// +/// [cosine *similarity*]: https://en.wikipedia.org/wiki/Cosine_similarity +/// [`Proximity::distance()`]: Proximity#method.distance +pub fn prenorm_cosine_similarity<T, U>(x: T, y: U) -> T::Value +where + T: Coordinates, + U: Coordinates<Value = T::Value>, + T::Value: Real, +{ + debug_assert!(x.dims() == y.dims()); + + let mut dot: T::Value = zero(); + + for i in 0..x.dims() { + dot += x.coord(i) * y.coord(i); + } + + dot +} + +/// Compute the [cosine distance] between two pre-normalized (unit magnitude) points. +/// +/// [cosine distance]: https://en.wikipedia.org/wiki/Cosine_similarity +pub fn prenorm_cosine_distance<T, U>(x: T, y: U) -> T::Value +where + T: Coordinates, + U: Coordinates<Value = T::Value>, + T::Value: Real, +{ + let one: T::Value = one(); + one - prenorm_cosine_similarity(x, y) +} + +/// Equips any [coordinate space] with the [cosine distance] function for pre-normalized (unit +/// magnitude) points. +/// +/// [coordinate space]: [Coordinates] +/// [cosine distance]: https://en.wikipedia.org/wiki/Cosine_similarity +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub struct PrenormCosine<T>(pub T); + +impl<T> Proximity for PrenormCosine<T> +where + T: Coordinates, + T::Value: Real, +{ + type Distance = T::Value; + + fn distance(&self, other: &Self) -> Self::Distance { + prenorm_cosine_distance(&self.0, &other.0) + } +} + +impl<T> Proximity<T> for PrenormCosine<T> +where + T: Coordinates, + T::Value: Real, +{ + type Distance = T::Value; + + fn distance(&self, other: &T) -> Self::Distance { + prenorm_cosine_distance(&self.0, other) + } +} + +impl<T> Proximity<PrenormCosine<T>> for T +where + T: Coordinates, + T::Value: Real, +{ + type Distance = T::Value; + + fn distance(&self, other: &PrenormCosine<T>) -> Self::Distance { + prenorm_cosine_distance(self, &other.0) + } +} + /// Compute the [angular distance] between two points. /// /// [angular distance]: https://en.wikipedia.org/wiki/Cosine_similarity#Angular_distance_and_similarity -pub fn angular_distance<T, U>(x: T, y: U) -> T::Value +pub fn angular_distance<T, U>(x: T, y: U) -> AngularDistance<T::Value> where T: Coordinates, U: Coordinates<Value = T::Value>, T::Value: Real, { - cosine_similarity(x, y).acos() + AngularDistance::from_cos(cosine_similarity(x, y)) } /// Equips any [coordinate space] with the [angular distance] metric. @@ -114,11 +196,12 @@ impl<T> Proximity for Angular<T> where T: Coordinates, T::Value: Real, + AngularDistance<T::Value>: Distance, { - type Distance = T::Value; + type Distance = AngularDistance<T::Value>; fn distance(&self, other: &Self) -> Self::Distance { - cosine_distance(&self.0, &other.0) + angular_distance(&self.0, &other.0) } } @@ -126,8 +209,9 @@ impl<T> Proximity<T> for Angular<T> where T: Coordinates, T::Value: Real, + AngularDistance<T::Value>: Distance, { - type Distance = T::Value; + type Distance = AngularDistance<T::Value>; fn distance(&self, other: &T) -> Self::Distance { angular_distance(&self.0, other) @@ -138,8 +222,9 @@ impl<T> Proximity<Angular<T>> for T where T: Coordinates, T::Value: Real, + AngularDistance<T::Value>: Distance, { - type Distance = T::Value; + type Distance = AngularDistance<T::Value>; fn distance(&self, other: &Angular<T>) -> Self::Distance { angular_distance(self, &other.0) @@ -151,6 +236,7 @@ impl<T> Metric for Angular<T> where T: Coordinates, T::Value: Real, + AngularDistance<T::Value>: Distance, {} /// Angular distance is a metric. @@ -158,6 +244,7 @@ impl<T> Metric<T> for Angular<T> where T: Coordinates, T::Value: Real, + AngularDistance<T::Value>: Distance, {} /// Angular distance is a metric. @@ -165,12 +252,179 @@ impl<T> Metric<Angular<T>> for T where T: Coordinates, T::Value: Real, + AngularDistance<T::Value>: Distance, +{} + +/// Compute the [angular distance] between two points. +/// +/// [angular distance]: https://en.wikipedia.org/wiki/Cosine_similarity#Angular_distance_and_similarity +pub fn prenorm_angular_distance<T, U>(x: T, y: U) -> AngularDistance<T::Value> +where + T: Coordinates, + U: Coordinates<Value = T::Value>, + T::Value: Real, +{ + AngularDistance::from_cos(prenorm_cosine_similarity(x, y)) +} + +/// Equips any [coordinate space] with the [angular distance] metric for pre-normalized (unit +/// magnitude) points. +/// +/// [coordinate space]: [Coordinates] +/// [angular distance]: https://en.wikipedia.org/wiki/Cosine_similarity#Angular_distance_and_similarity +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub struct PrenormAngular<T>(pub T); + +impl<T> Proximity for PrenormAngular<T> +where + T: Coordinates, + T::Value: Real, + AngularDistance<T::Value>: Distance, +{ + type Distance = AngularDistance<T::Value>; + + fn distance(&self, other: &Self) -> Self::Distance { + prenorm_angular_distance(&self.0, &other.0) + } +} + +impl<T> Proximity<T> for PrenormAngular<T> +where + T: Coordinates, + T::Value: Real, + AngularDistance<T::Value>: Distance, +{ + type Distance = AngularDistance<T::Value>; + + fn distance(&self, other: &T) -> Self::Distance { + prenorm_angular_distance(&self.0, other) + } +} + +impl<T> Proximity<PrenormAngular<T>> for T +where + T: Coordinates, + T::Value: Real, + AngularDistance<T::Value>: Distance, +{ + type Distance = AngularDistance<T::Value>; + + fn distance(&self, other: &PrenormAngular<T>) -> Self::Distance { + prenorm_angular_distance(self, &other.0) + } +} + +/// Angular distance is a metric. +impl<T> Metric for PrenormAngular<T> +where + T: Coordinates, + T::Value: Real, + AngularDistance<T::Value>: Distance, {} +/// Angular distance is a metric. +impl<T> Metric<T> for PrenormAngular<T> +where + T: Coordinates, + T::Value: Real, + AngularDistance<T::Value>: Distance, +{} + +/// Angular distance is a metric. +impl<T> Metric<PrenormAngular<T>> for T +where + T: Coordinates, + T::Value: Real, + AngularDistance<T::Value>: Distance, +{} + +/// An [angular distance]. +/// +/// This type stores the cosine of the angle, to avoid computing the expensive trancendental +/// `acos()` function until absolutely necessary. +/// +/// # use acap::distance::Distance; +/// # use acap::cos::AngularDistance; +/// let zero = AngularDistance::from_cos(1.0); +/// let pi_2 = AngularDistance::from_cos(0.0); +/// let pi = AngularDistance::from_cos(-1.0); +/// assert!(zero < pi_2 && pi_2 < pi); +/// +/// [angular distance]: https://en.wikipedia.org/wiki/Cosine_similarity#Angular_distance_and_similarity +#[derive(Clone, Copy, Debug, PartialEq)] +pub struct AngularDistance<T>(T); + +impl<T: Real + Value> AngularDistance<T> { + /// Creates an `AngularDistance` from the cosine of an angle. + pub fn from_cos(value: T) -> Self { + Self(value) + } + + /// Get the cosine of this angle. + pub fn cos(self) -> T { + self.0 + } +} + +impl<T: PartialOrd> PartialOrd for AngularDistance<T> { + fn partial_cmp(&self, other: &AngularDistance<T>) -> Option<Ordering> { + // acos() is decreasing, so swap the comparison order + other.0.partial_cmp(&self.0) + } +} + +macro_rules! impl_distance { + ($f:ty) => { + impl From<AngularDistance<$f>> for $f { + #[inline] + fn from(value: AngularDistance<$f>) -> $f { + value.0.acos() + } + } + + impl PartialOrd<$f> for AngularDistance<$f> { + #[inline] + fn partial_cmp(&self, other: &$f) -> Option<Ordering> { + self.value().partial_cmp(other) + } + } + + impl PartialOrd<AngularDistance<$f>> for $f { + #[inline] + fn partial_cmp(&self, other: &AngularDistance<$f>) -> Option<Ordering> { + self.partial_cmp(&other.value()) + } + } + + impl PartialEq<$f> for AngularDistance<$f> { + #[inline] + fn eq(&self, other: &$f) -> bool { + self.value() == *other + } + } + + impl PartialEq<AngularDistance<$f>> for $f { + #[inline] + fn eq(&self, other: &AngularDistance<$f>) -> bool { + *self == other.value() + } + } + + impl Distance for AngularDistance<$f> { + type Value = $f; + } + } +} + +impl_distance!(f32); +impl_distance!(f64); + #[cfg(test)] mod tests { use super::*; + use std::f64::consts::{FRAC_PI_2, FRAC_PI_4, PI, SQRT_2}; + #[test] fn test_cosine() { assert_eq!(cosine_distance([3.0, 4.0], [3.0, 4.0]), 0.0); @@ -180,16 +434,78 @@ mod tests { } #[test] + fn test_prenorm_cosine() { + assert_eq!(prenorm_cosine_distance([0.6, 0.8], [0.6, 0.8]), 0.0); + assert_eq!(prenorm_cosine_distance([0.6, 0.8], [-0.8, 0.6]), 1.0); + assert_eq!(prenorm_cosine_distance([0.6, 0.8], [-0.6, -0.8]), 2.0); + assert_eq!(prenorm_cosine_distance([0.6, 0.8], [0.8, -0.6]), 1.0); + } + + #[test] fn test_angular() { - use std::f64::consts::{FRAC_PI_2, FRAC_PI_4, PI}; + let zero = angular_distance([3.0, 4.0], [3.0, 4.0]); + let pi_4 = Angular([0.0, 1.0]).distance(&Angular([1.0, 1.0])); + let pi_2 = Angular([3.0, 4.0]).distance(&[-4.0, 3.0]); + let pi = [3.0, 4.0].distance(&Angular([-3.0, -4.0])); + + assert_eq!(zero.cos(), 1.0); + assert_eq!(pi_2.cos(), 0.0); + assert_eq!(pi.cos(), -1.0); + + assert_eq!(zero, 0.0); + + assert!(zero < pi_4); + assert!(zero < pi_2); + assert!(zero < pi); - assert_eq!(angular_distance([3.0, 4.0], [3.0, 4.0]), 0.0); + assert!(pi_4 < pi_2); + assert!(pi_4 < pi); - assert!((angular_distance([3.0, 4.0], [-4.0, 3.0]) - FRAC_PI_2).abs() < 1.0e-9); - assert!((angular_distance([3.0, 4.0], [-3.0, -4.0]) - PI).abs() < 1.0e-9); - assert!((angular_distance([3.0, 4.0], [4.0, -3.0]) - FRAC_PI_2).abs() < 1.0e-9); + assert!(pi_2 < pi); - assert!((angular_distance([0.0, 1.0], [1.0, 1.0]) - FRAC_PI_4).abs() < 1.0e-9); + assert!(FRAC_PI_4 < pi_2); + assert!(pi_2 > FRAC_PI_4); + + assert!(pi_2 < PI); + assert!(PI > pi_2); + + assert!((pi_4.value() - FRAC_PI_4).abs() < 1.0e-9); + assert!((pi_2.value() - FRAC_PI_2).abs() < 1.0e-9); + assert!((pi.value() - PI).abs() < 1.0e-9); } -} + #[test] + fn test_prenorm_angular() { + let sqrt_2_inv = 1.0 / SQRT_2; + + let zero = prenorm_angular_distance([0.6, 0.8], [0.6, 0.8]); + let pi_4 = PrenormAngular([0.0, 1.0]).distance(&PrenormAngular([sqrt_2_inv, sqrt_2_inv])); + let pi_2 = PrenormAngular([0.6, 0.8]).distance(&[-0.8, 0.6]); + let pi = [0.6, 0.8].distance(&PrenormAngular([-0.6, -0.8])); + + assert_eq!(zero.cos(), 1.0); + assert_eq!(pi_2.cos(), 0.0); + assert_eq!(pi.cos(), -1.0); + + assert_eq!(zero, 0.0); + + assert!(zero < pi_4); + assert!(zero < pi_2); + assert!(zero < pi); + + assert!(pi_4 < pi_2); + assert!(pi_4 < pi); + + assert!(pi_2 < pi); + + assert!(FRAC_PI_4 < pi_2); + assert!(pi_2 > FRAC_PI_4); + + assert!(pi_2 < PI); + assert!(PI > pi_2); + + assert!((pi_4.value() - FRAC_PI_4).abs() < 1.0e-9); + assert!((pi_2.value() - FRAC_PI_2).abs() < 1.0e-9); + assert!((pi.value() - PI).abs() < 1.0e-9); + } +} |