From 14700072aff548fff187c082f09a7eaf210079de Mon Sep 17 00:00:00 2001 From: Yomguithereal Date: Tue, 16 Apr 2019 17:07:19 +0200 Subject: [PATCH] More abstract --- fog/clustering/passjoin.py | 7 ++++--- setup.py | 1 - test/clustering/passjoin_test.py | 13 +++++++------ 3 files changed, 11 insertions(+), 10 deletions(-) diff --git a/fog/clustering/passjoin.py b/fog/clustering/passjoin.py index 1bdd224..d505141 100644 --- a/fog/clustering/passjoin.py +++ b/fog/clustering/passjoin.py @@ -20,7 +20,6 @@ # http://people.csail.mit.edu/dongdeng/projects/passjoin/index.html # from collections import defaultdict -from Levenshtein import distance as levenshtein from fog.clustering.utils import clusters_from_pairs @@ -187,7 +186,7 @@ def multi_match_aware_substrings(k, string, l, i, pi, li): current_substring = substring -def passjoin(data, k, sort=True, min_size=2, max_size=float('inf'), +def passjoin(data, k, distance, sort=True, min_size=2, max_size=float('inf'), mode='connected_components'): """ Function returning an iterator over found clusters using the PassJoin @@ -207,6 +206,8 @@ def passjoin(data, k, sort=True, min_size=2, max_size=float('inf'), data (iterable): Arbitrary iterable containing data points to gather into clusters. Will be fully consumed. k (number): Levenshtein distance threshold. + distance (callable): Function tasked to compute the Levenshtein distance + between two points of data. sort (boolean, optional): whether to sort the data beforehand. Defaults to False. min_size (number, optional): minimum number of items in a cluster for @@ -252,7 +253,7 @@ def passjoin(data, k, sort=True, min_size=2, max_size=float('inf'), # NOTE: first condition is here not to compute Levenshtein # distance for tiny strings - if (s <= k and l <= k) or levenshtein(A, B) <= k: + if (s <= k and l <= k) or distance(A, B) <= k: yield (A, B) # Indexing the string diff --git a/setup.py b/setup.py index 02bbbbc..01cd617 100644 --- a/setup.py +++ b/setup.py @@ -26,7 +26,6 @@ setup(name='fog', install_requires=[ 'dill>=0.2.7.1', 'phylactery>=0.1.1', - 'python-Levenshtein>=0.12.0', 'Unidecode>=1.0.22' ], entry_points={ diff --git a/test/clustering/passjoin_test.py b/test/clustering/passjoin_test.py index e677ae0..4bdf6d5 100644 --- a/test/clustering/passjoin_test.py +++ b/test/clustering/passjoin_test.py @@ -2,6 +2,7 @@ # Fog PassJoin Unit Tests # ============================================================================= import csv +from Levenshtein import distance as levenshtein from test.clustering.utils import Clusters from fog.clustering import passjoin from fog.clustering.passjoin import ( @@ -121,28 +122,28 @@ class TestPassJoins(object): def test_passjoin(self): # k = 1 - clusters = Clusters(passjoin(STRINGS, 1)) + clusters = Clusters(passjoin(STRINGS, 1, distance=levenshtein)) assert clusters == CLUSTERS_K1 - clusters = Clusters(passjoin(STRINGS, 1, sort=False)) + clusters = Clusters(passjoin(STRINGS, 1, distance=levenshtein, sort=False)) assert clusters == CLUSTERS_K1 # k = 2 - clusters = Clusters(passjoin(STRINGS, 2)) + clusters = Clusters(passjoin(STRINGS, 2, distance=levenshtein)) assert clusters == CLUSTERS_K2 - clusters = Clusters(passjoin(STRINGS, 2, sort=False)) + clusters = Clusters(passjoin(STRINGS, 2, distance=levenshtein, sort=False)) assert clusters == CLUSTERS_K2 # k = 3 - clusters = Clusters(passjoin(STRINGS, 3)) + clusters = Clusters(passjoin(STRINGS, 3, distance=levenshtein)) assert clusters == CLUSTERS_K3 - clusters = Clusters(passjoin(STRINGS, 3, sort=False)) + clusters = Clusters(passjoin(STRINGS, 3, distance=levenshtein, sort=False)) assert clusters == CLUSTERS_K3