From d7b4841c42660c098ab860762fbce2e3706b44b0 Mon Sep 17 00:00:00 2001 From: NicklasXYZ Date: Sun, 17 Mar 2024 00:03:50 +0100 Subject: [PATCH] Add set similarity measures --- src/gleam_community/maths/metrics.gleam | 129 +++++++++++++++++- test/gleam_community/maths/metrics_test.gleam | 22 +++ 2 files changed, 147 insertions(+), 4 deletions(-) diff --git a/src/gleam_community/maths/metrics.gleam b/src/gleam_community/maths/metrics.gleam index e35e336..cd9bfa6 100644 --- a/src/gleam_community/maths/metrics.gleam +++ b/src/gleam_community/maths/metrics.gleam @@ -30,6 +30,9 @@ //// * [`manhatten_distance`](#float_manhatten_distance) //// * [`minkowski_distance`](#minkowski_distance) //// * [`euclidean_distance`](#euclidean_distance) +//// * [`jaccard_index`](#jaccard_index) +//// * [`sorensen_dice_coefficient`](#sorensen_dice_coefficient) +//// * [`tversky_index`](#tversky_index) //// * **Basic statistical measures** //// * [`mean`](#mean) //// * [`median`](#median) @@ -44,6 +47,7 @@ import gleam_community/maths/predicates import gleam_community/maths/conversion import gleam/list import gleam/pair +import gleam/set import gleam/float ///
@@ -292,7 +296,7 @@ pub fn euclidean_distance( } ///
-/// +/// /// Spot a typo? Open an issue! /// ///
@@ -347,7 +351,7 @@ pub fn mean(arr: List(Float)) -> Result(Float, String) { } ///
-/// +/// /// Spot a typo? Open an issue! /// ///
@@ -414,7 +418,7 @@ pub fn median(arr: List(Float)) -> Result(Float, String) { } ///
-/// +/// /// Spot a typo? Open an issue! /// ///
@@ -490,7 +494,7 @@ pub fn variance(arr: List(Float), ddof: Int) -> Result(Float, String) { } ///
-/// +/// /// Spot a typo? Open an issue! /// ///
@@ -555,3 +559,120 @@ pub fn standard_deviation(arr: List(Float), ddof: Int) -> Result(Float, String) } } } + +///
+/// +/// Spot a typo? Open an issue! +/// +///
+/// +///
+/// Example: +/// +/// import gleeunit/should +/// import gleam_community/maths/metrics +/// +/// pub fn example () { +/// } +///
+/// +///
+/// +/// Back to top ↑ +/// +///
+/// +pub fn jaccard_index(aset: set.Set(a), bset: set.Set(a)) -> Float { + let assert Ok(result) = tversky_index(aset, bset, 1.0, 1.0) + result +} + +///
+/// +/// Spot a typo? Open an issue! +/// +///
+/// +///
+/// Example: +/// +/// import gleeunit/should +/// import gleam_community/maths/metrics +/// +/// pub fn example () { +/// } +///
+/// +///
+/// +/// Back to top ↑ +/// +///
+/// +pub fn sorensen_dice_coefficient(aset: set.Set(a), bset: set.Set(a)) -> Float { + let assert Ok(result) = tversky_index(aset, bset, 0.5, 0.5) + result +} + +///
+/// +/// Spot a typo? Open an issue! +/// +///
+/// +/// The Tversky index is a generalization of the Sørensen–Dice coefficient and the Jaccard index. +/// +///
+/// Example: +/// +/// import gleeunit/should +/// import gleam_community/maths/metrics +/// +/// pub fn example () { +/// } +///
+/// +///
+/// +/// Back to top ↑ +/// +///
+/// +pub fn tversky_index( + aset: set.Set(a), + bset: set.Set(a), + alpha: Float, + beta: Float, +) -> Result(Float, String) { + case alpha >=. 0.0, beta >=. 0.0 { + True, True -> { + let intersection: Float = + set.intersection(aset, bset) + |> set.size() + |> conversion.int_to_float() + let difference1: Float = + set.difference(aset, bset) + |> set.size() + |> conversion.int_to_float() + let difference2: Float = + set.difference(bset, aset) + |> set.size() + |> conversion.int_to_float() + intersection + /. { intersection +. alpha *. difference1 +. beta *. difference2 } + |> Ok + } + False, True -> { + "Invalid input argument: alpha < 0. Valid input is alpha >= 0." + |> Error + } + True, False -> { + "Invalid input argument: beta < 0. Valid input is beta >= 0." + |> Error + } + _, _ -> { + "Invalid input argument: alpha < 0 and beta < 0. Valid input is alpha >= 0 and beta >= 0." + |> Error + } + } +} diff --git a/test/gleam_community/maths/metrics_test.gleam b/test/gleam_community/maths/metrics_test.gleam index 8e407e6..cbd8d5e 100644 --- a/test/gleam_community/maths/metrics_test.gleam +++ b/test/gleam_community/maths/metrics_test.gleam @@ -2,6 +2,7 @@ import gleam_community/maths/elementary import gleam_community/maths/metrics import gleam_community/maths/predicates import gleeunit/should +import gleam/set pub fn float_list_norm_test() { let assert Ok(tol) = elementary.power(-10.0, -6.0) @@ -212,3 +213,24 @@ pub fn example_standard_deviation_test() { |> metrics.standard_deviation(ddof) |> should.equal(Ok(1.0)) } + +pub fn example_jaccard_index_test() { + metrics.jaccard_index(set.from_list([]), set.from_list([])) + |> should.equal(0.0) + + let set_a: set.Set(Int) = set.from_list([0, 1, 2, 5, 6, 8, 9]) + let set_b: set.Set(Int) = set.from_list([0, 2, 3, 4, 5, 7, 9]) + metrics.jaccard_index(set_a, set_b) + |> should.equal(4.0 /. 10.0) + + let set_c: set.Set(Int) = set.from_list([0, 1, 2, 3, 4, 5]) + let set_d: set.Set(Int) = set.from_list([6, 7, 8, 9, 10]) + metrics.jaccard_index(set_c, set_d) + |> should.equal(0.0 /. 11.0) + + let set_e: set.Set(String) = set.from_list(["cat", "dog", "hippo", "monkey"]) + let set_f: set.Set(String) = + set.from_list(["monkey", "rhino", "ostrich", "salmon"]) + metrics.jaccard_index(set_e, set_f) + |> should.equal(1.0 /. 7.0) +}