summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorTavian Barnes <tavianator@tavianator.com>2020-06-25 11:44:47 -0400
committerTavian Barnes <tavianator@tavianator.com>2020-06-25 11:44:47 -0400
commit57f4d9dbe851439b24e31977b8c5dc60e246dda3 (patch)
tree68d8726e89cc371f7bb9546832a74c74fce9510e
parentc53a3031f7a8ea0578634d53597c2817f586665b (diff)
downloadacap-57f4d9dbe851439b24e31977b8c5dc60e246dda3.tar.xz
cos: Add prenormalized cosine/angular distances, and an order embedding
-rw-r--r--src/cos.rs348
1 files changed, 332 insertions, 16 deletions
diff --git a/src/cos.rs b/src/cos.rs
index 2fde4ce..3d3219c 100644
--- a/src/cos.rs
+++ b/src/cos.rs
@@ -1,17 +1,19 @@
//! [Cosine distance](https://en.wikipedia.org/wiki/Cosine_similarity).
use crate::coords::Coordinates;
-use crate::distance::{Metric, Proximity};
+use crate::distance::{Distance, Metric, Proximity, Value};
use num_traits::real::Real;
use num_traits::{one, zero};
+use std::cmp::Ordering;
+
/// Compute the [cosine *similarity*] between two points.
///
-/// This is not suitable for implementing [`Proximity::distance()`] because the result is reversed
+/// Use [cosine_distance] instead if you are implementing [Proximity::distance()].
///
/// [cosine *similarity*]: https://en.wikipedia.org/wiki/Cosine_similarity
-/// [`Proximity::distance()`]: Proximity#method.distance
+/// [Proximity::distance()]: Proximity#method.distance
pub fn cosine_similarity<T, U>(x: T, y: U) -> T::Value
where
T: Coordinates,
@@ -91,16 +93,96 @@ where
}
}
+/// Compute the [cosine *similarity*] between two pre-normalized (unit magnitude) points.
+///
+/// Use [prenorm_cosine_distance] instead if you are implementing [Proximity::distance()].
+///
+/// [cosine *similarity*]: https://en.wikipedia.org/wiki/Cosine_similarity
+/// [`Proximity::distance()`]: Proximity#method.distance
+pub fn prenorm_cosine_similarity<T, U>(x: T, y: U) -> T::Value
+where
+ T: Coordinates,
+ U: Coordinates<Value = T::Value>,
+ T::Value: Real,
+{
+ debug_assert!(x.dims() == y.dims());
+
+ let mut dot: T::Value = zero();
+
+ for i in 0..x.dims() {
+ dot += x.coord(i) * y.coord(i);
+ }
+
+ dot
+}
+
+/// Compute the [cosine distance] between two pre-normalized (unit magnitude) points.
+///
+/// [cosine distance]: https://en.wikipedia.org/wiki/Cosine_similarity
+pub fn prenorm_cosine_distance<T, U>(x: T, y: U) -> T::Value
+where
+ T: Coordinates,
+ U: Coordinates<Value = T::Value>,
+ T::Value: Real,
+{
+ let one: T::Value = one();
+ one - prenorm_cosine_similarity(x, y)
+}
+
+/// Equips any [coordinate space] with the [cosine distance] function for pre-normalized (unit
+/// magnitude) points.
+///
+/// [coordinate space]: [Coordinates]
+/// [cosine distance]: https://en.wikipedia.org/wiki/Cosine_similarity
+#[derive(Clone, Copy, Debug, Eq, PartialEq)]
+pub struct PrenormCosine<T>(pub T);
+
+impl<T> Proximity for PrenormCosine<T>
+where
+ T: Coordinates,
+ T::Value: Real,
+{
+ type Distance = T::Value;
+
+ fn distance(&self, other: &Self) -> Self::Distance {
+ prenorm_cosine_distance(&self.0, &other.0)
+ }
+}
+
+impl<T> Proximity<T> for PrenormCosine<T>
+where
+ T: Coordinates,
+ T::Value: Real,
+{
+ type Distance = T::Value;
+
+ fn distance(&self, other: &T) -> Self::Distance {
+ prenorm_cosine_distance(&self.0, other)
+ }
+}
+
+impl<T> Proximity<PrenormCosine<T>> for T
+where
+ T: Coordinates,
+ T::Value: Real,
+{
+ type Distance = T::Value;
+
+ fn distance(&self, other: &PrenormCosine<T>) -> Self::Distance {
+ prenorm_cosine_distance(self, &other.0)
+ }
+}
+
/// Compute the [angular distance] between two points.
///
/// [angular distance]: https://en.wikipedia.org/wiki/Cosine_similarity#Angular_distance_and_similarity
-pub fn angular_distance<T, U>(x: T, y: U) -> T::Value
+pub fn angular_distance<T, U>(x: T, y: U) -> AngularDistance<T::Value>
where
T: Coordinates,
U: Coordinates<Value = T::Value>,
T::Value: Real,
{
- cosine_similarity(x, y).acos()
+ AngularDistance::from_cos(cosine_similarity(x, y))
}
/// Equips any [coordinate space] with the [angular distance] metric.
@@ -114,11 +196,12 @@ impl<T> Proximity for Angular<T>
where
T: Coordinates,
T::Value: Real,
+ AngularDistance<T::Value>: Distance,
{
- type Distance = T::Value;
+ type Distance = AngularDistance<T::Value>;
fn distance(&self, other: &Self) -> Self::Distance {
- cosine_distance(&self.0, &other.0)
+ angular_distance(&self.0, &other.0)
}
}
@@ -126,8 +209,9 @@ impl<T> Proximity<T> for Angular<T>
where
T: Coordinates,
T::Value: Real,
+ AngularDistance<T::Value>: Distance,
{
- type Distance = T::Value;
+ type Distance = AngularDistance<T::Value>;
fn distance(&self, other: &T) -> Self::Distance {
angular_distance(&self.0, other)
@@ -138,8 +222,9 @@ impl<T> Proximity<Angular<T>> for T
where
T: Coordinates,
T::Value: Real,
+ AngularDistance<T::Value>: Distance,
{
- type Distance = T::Value;
+ type Distance = AngularDistance<T::Value>;
fn distance(&self, other: &Angular<T>) -> Self::Distance {
angular_distance(self, &other.0)
@@ -151,6 +236,7 @@ impl<T> Metric for Angular<T>
where
T: Coordinates,
T::Value: Real,
+ AngularDistance<T::Value>: Distance,
{}
/// Angular distance is a metric.
@@ -158,6 +244,7 @@ impl<T> Metric<T> for Angular<T>
where
T: Coordinates,
T::Value: Real,
+ AngularDistance<T::Value>: Distance,
{}
/// Angular distance is a metric.
@@ -165,12 +252,179 @@ impl<T> Metric<Angular<T>> for T
where
T: Coordinates,
T::Value: Real,
+ AngularDistance<T::Value>: Distance,
+{}
+
+/// Compute the [angular distance] between two points.
+///
+/// [angular distance]: https://en.wikipedia.org/wiki/Cosine_similarity#Angular_distance_and_similarity
+pub fn prenorm_angular_distance<T, U>(x: T, y: U) -> AngularDistance<T::Value>
+where
+ T: Coordinates,
+ U: Coordinates<Value = T::Value>,
+ T::Value: Real,
+{
+ AngularDistance::from_cos(prenorm_cosine_similarity(x, y))
+}
+
+/// Equips any [coordinate space] with the [angular distance] metric for pre-normalized (unit
+/// magnitude) points.
+///
+/// [coordinate space]: [Coordinates]
+/// [angular distance]: https://en.wikipedia.org/wiki/Cosine_similarity#Angular_distance_and_similarity
+#[derive(Clone, Copy, Debug, Eq, PartialEq)]
+pub struct PrenormAngular<T>(pub T);
+
+impl<T> Proximity for PrenormAngular<T>
+where
+ T: Coordinates,
+ T::Value: Real,
+ AngularDistance<T::Value>: Distance,
+{
+ type Distance = AngularDistance<T::Value>;
+
+ fn distance(&self, other: &Self) -> Self::Distance {
+ prenorm_angular_distance(&self.0, &other.0)
+ }
+}
+
+impl<T> Proximity<T> for PrenormAngular<T>
+where
+ T: Coordinates,
+ T::Value: Real,
+ AngularDistance<T::Value>: Distance,
+{
+ type Distance = AngularDistance<T::Value>;
+
+ fn distance(&self, other: &T) -> Self::Distance {
+ prenorm_angular_distance(&self.0, other)
+ }
+}
+
+impl<T> Proximity<PrenormAngular<T>> for T
+where
+ T: Coordinates,
+ T::Value: Real,
+ AngularDistance<T::Value>: Distance,
+{
+ type Distance = AngularDistance<T::Value>;
+
+ fn distance(&self, other: &PrenormAngular<T>) -> Self::Distance {
+ prenorm_angular_distance(self, &other.0)
+ }
+}
+
+/// Angular distance is a metric.
+impl<T> Metric for PrenormAngular<T>
+where
+ T: Coordinates,
+ T::Value: Real,
+ AngularDistance<T::Value>: Distance,
{}
+/// Angular distance is a metric.
+impl<T> Metric<T> for PrenormAngular<T>
+where
+ T: Coordinates,
+ T::Value: Real,
+ AngularDistance<T::Value>: Distance,
+{}
+
+/// Angular distance is a metric.
+impl<T> Metric<PrenormAngular<T>> for T
+where
+ T: Coordinates,
+ T::Value: Real,
+ AngularDistance<T::Value>: Distance,
+{}
+
+/// An [angular distance].
+///
+/// This type stores the cosine of the angle, to avoid computing the expensive trancendental
+/// `acos()` function until absolutely necessary.
+///
+/// # use acap::distance::Distance;
+/// # use acap::cos::AngularDistance;
+/// let zero = AngularDistance::from_cos(1.0);
+/// let pi_2 = AngularDistance::from_cos(0.0);
+/// let pi = AngularDistance::from_cos(-1.0);
+/// assert!(zero < pi_2 && pi_2 < pi);
+///
+/// [angular distance]: https://en.wikipedia.org/wiki/Cosine_similarity#Angular_distance_and_similarity
+#[derive(Clone, Copy, Debug, PartialEq)]
+pub struct AngularDistance<T>(T);
+
+impl<T: Real + Value> AngularDistance<T> {
+ /// Creates an `AngularDistance` from the cosine of an angle.
+ pub fn from_cos(value: T) -> Self {
+ Self(value)
+ }
+
+ /// Get the cosine of this angle.
+ pub fn cos(self) -> T {
+ self.0
+ }
+}
+
+impl<T: PartialOrd> PartialOrd for AngularDistance<T> {
+ fn partial_cmp(&self, other: &AngularDistance<T>) -> Option<Ordering> {
+ // acos() is decreasing, so swap the comparison order
+ other.0.partial_cmp(&self.0)
+ }
+}
+
+macro_rules! impl_distance {
+ ($f:ty) => {
+ impl From<AngularDistance<$f>> for $f {
+ #[inline]
+ fn from(value: AngularDistance<$f>) -> $f {
+ value.0.acos()
+ }
+ }
+
+ impl PartialOrd<$f> for AngularDistance<$f> {
+ #[inline]
+ fn partial_cmp(&self, other: &$f) -> Option<Ordering> {
+ self.value().partial_cmp(other)
+ }
+ }
+
+ impl PartialOrd<AngularDistance<$f>> for $f {
+ #[inline]
+ fn partial_cmp(&self, other: &AngularDistance<$f>) -> Option<Ordering> {
+ self.partial_cmp(&other.value())
+ }
+ }
+
+ impl PartialEq<$f> for AngularDistance<$f> {
+ #[inline]
+ fn eq(&self, other: &$f) -> bool {
+ self.value() == *other
+ }
+ }
+
+ impl PartialEq<AngularDistance<$f>> for $f {
+ #[inline]
+ fn eq(&self, other: &AngularDistance<$f>) -> bool {
+ *self == other.value()
+ }
+ }
+
+ impl Distance for AngularDistance<$f> {
+ type Value = $f;
+ }
+ }
+}
+
+impl_distance!(f32);
+impl_distance!(f64);
+
#[cfg(test)]
mod tests {
use super::*;
+ use std::f64::consts::{FRAC_PI_2, FRAC_PI_4, PI, SQRT_2};
+
#[test]
fn test_cosine() {
assert_eq!(cosine_distance([3.0, 4.0], [3.0, 4.0]), 0.0);
@@ -180,16 +434,78 @@ mod tests {
}
#[test]
+ fn test_prenorm_cosine() {
+ assert_eq!(prenorm_cosine_distance([0.6, 0.8], [0.6, 0.8]), 0.0);
+ assert_eq!(prenorm_cosine_distance([0.6, 0.8], [-0.8, 0.6]), 1.0);
+ assert_eq!(prenorm_cosine_distance([0.6, 0.8], [-0.6, -0.8]), 2.0);
+ assert_eq!(prenorm_cosine_distance([0.6, 0.8], [0.8, -0.6]), 1.0);
+ }
+
+ #[test]
fn test_angular() {
- use std::f64::consts::{FRAC_PI_2, FRAC_PI_4, PI};
+ let zero = angular_distance([3.0, 4.0], [3.0, 4.0]);
+ let pi_4 = Angular([0.0, 1.0]).distance(&Angular([1.0, 1.0]));
+ let pi_2 = Angular([3.0, 4.0]).distance(&[-4.0, 3.0]);
+ let pi = [3.0, 4.0].distance(&Angular([-3.0, -4.0]));
+
+ assert_eq!(zero.cos(), 1.0);
+ assert_eq!(pi_2.cos(), 0.0);
+ assert_eq!(pi.cos(), -1.0);
+
+ assert_eq!(zero, 0.0);
+
+ assert!(zero < pi_4);
+ assert!(zero < pi_2);
+ assert!(zero < pi);
- assert_eq!(angular_distance([3.0, 4.0], [3.0, 4.0]), 0.0);
+ assert!(pi_4 < pi_2);
+ assert!(pi_4 < pi);
- assert!((angular_distance([3.0, 4.0], [-4.0, 3.0]) - FRAC_PI_2).abs() < 1.0e-9);
- assert!((angular_distance([3.0, 4.0], [-3.0, -4.0]) - PI).abs() < 1.0e-9);
- assert!((angular_distance([3.0, 4.0], [4.0, -3.0]) - FRAC_PI_2).abs() < 1.0e-9);
+ assert!(pi_2 < pi);
- assert!((angular_distance([0.0, 1.0], [1.0, 1.0]) - FRAC_PI_4).abs() < 1.0e-9);
+ assert!(FRAC_PI_4 < pi_2);
+ assert!(pi_2 > FRAC_PI_4);
+
+ assert!(pi_2 < PI);
+ assert!(PI > pi_2);
+
+ assert!((pi_4.value() - FRAC_PI_4).abs() < 1.0e-9);
+ assert!((pi_2.value() - FRAC_PI_2).abs() < 1.0e-9);
+ assert!((pi.value() - PI).abs() < 1.0e-9);
}
-}
+ #[test]
+ fn test_prenorm_angular() {
+ let sqrt_2_inv = 1.0 / SQRT_2;
+
+ let zero = prenorm_angular_distance([0.6, 0.8], [0.6, 0.8]);
+ let pi_4 = PrenormAngular([0.0, 1.0]).distance(&PrenormAngular([sqrt_2_inv, sqrt_2_inv]));
+ let pi_2 = PrenormAngular([0.6, 0.8]).distance(&[-0.8, 0.6]);
+ let pi = [0.6, 0.8].distance(&PrenormAngular([-0.6, -0.8]));
+
+ assert_eq!(zero.cos(), 1.0);
+ assert_eq!(pi_2.cos(), 0.0);
+ assert_eq!(pi.cos(), -1.0);
+
+ assert_eq!(zero, 0.0);
+
+ assert!(zero < pi_4);
+ assert!(zero < pi_2);
+ assert!(zero < pi);
+
+ assert!(pi_4 < pi_2);
+ assert!(pi_4 < pi);
+
+ assert!(pi_2 < pi);
+
+ assert!(FRAC_PI_4 < pi_2);
+ assert!(pi_2 > FRAC_PI_4);
+
+ assert!(pi_2 < PI);
+ assert!(PI > pi_2);
+
+ assert!((pi_4.value() - FRAC_PI_4).abs() < 1.0e-9);
+ assert!((pi_2.value() - FRAC_PI_2).abs() < 1.0e-9);
+ assert!((pi.value() - PI).abs() < 1.0e-9);
+ }
+}