More abstract

This commit is contained in:
Yomguithereal 2019-04-16 17:07:19 +02:00
parent 63fc03d04e
commit 14700072af
3 changed files with 11 additions and 10 deletions

View File

@ -20,7 +20,6 @@
# http://people.csail.mit.edu/dongdeng/projects/passjoin/index.html
#
from collections import defaultdict
from Levenshtein import distance as levenshtein
from fog.clustering.utils import clusters_from_pairs
@ -187,7 +186,7 @@ def multi_match_aware_substrings(k, string, l, i, pi, li):
current_substring = substring
def passjoin(data, k, sort=True, min_size=2, max_size=float('inf'),
def passjoin(data, k, distance, sort=True, min_size=2, max_size=float('inf'),
mode='connected_components'):
"""
Function returning an iterator over found clusters using the PassJoin
@ -207,6 +206,8 @@ def passjoin(data, k, sort=True, min_size=2, max_size=float('inf'),
data (iterable): Arbitrary iterable containing data points to gather
into clusters. Will be fully consumed.
k (number): Levenshtein distance threshold.
distance (callable): Function tasked to compute the Levenshtein distance
between two points of data.
sort (boolean, optional): whether to sort the data beforehand. Defaults
to True.
min_size (number, optional): minimum number of items in a cluster for
@ -252,7 +253,7 @@ def passjoin(data, k, sort=True, min_size=2, max_size=float('inf'),
# NOTE: the first condition short-circuits so that the Levenshtein
# distance is not computed for tiny strings
if (s <= k and l <= k) or levenshtein(A, B) <= k:
if (s <= k and l <= k) or distance(A, B) <= k:
yield (A, B)
# Indexing the string

View File

@ -26,7 +26,6 @@ setup(name='fog',
install_requires=[
'dill>=0.2.7.1',
'phylactery>=0.1.1',
'python-Levenshtein>=0.12.0',
'Unidecode>=1.0.22'
],
entry_points={

View File

@ -2,6 +2,7 @@
# Fog PassJoin Unit Tests
# =============================================================================
import csv
from Levenshtein import distance as levenshtein
from test.clustering.utils import Clusters
from fog.clustering import passjoin
from fog.clustering.passjoin import (
@ -121,28 +122,28 @@ class TestPassJoins(object):
def test_passjoin(self):
# k = 1
clusters = Clusters(passjoin(STRINGS, 1))
clusters = Clusters(passjoin(STRINGS, 1, distance=levenshtein))
assert clusters == CLUSTERS_K1
clusters = Clusters(passjoin(STRINGS, 1, sort=False))
clusters = Clusters(passjoin(STRINGS, 1, distance=levenshtein, sort=False))
assert clusters == CLUSTERS_K1
# k = 2
clusters = Clusters(passjoin(STRINGS, 2))
clusters = Clusters(passjoin(STRINGS, 2, distance=levenshtein))
assert clusters == CLUSTERS_K2
clusters = Clusters(passjoin(STRINGS, 2, sort=False))
clusters = Clusters(passjoin(STRINGS, 2, distance=levenshtein, sort=False))
assert clusters == CLUSTERS_K2
# k = 3
clusters = Clusters(passjoin(STRINGS, 3))
clusters = Clusters(passjoin(STRINGS, 3, distance=levenshtein))
assert clusters == CLUSTERS_K3
clusters = Clusters(passjoin(STRINGS, 3, sort=False))
clusters = Clusters(passjoin(STRINGS, 3, distance=levenshtein, sort=False))
assert clusters == CLUSTERS_K3