Fixing jaccard_intersection_index

This commit is contained in:
Yomguithereal 2018-07-02 21:07:28 +02:00
parent 27441060d4
commit 7fda4ce33c
1 changed files with 19 additions and 15 deletions

View File

@ -30,13 +30,12 @@ def jaccard_intersection_index(data, radius=0.8, key=None, min_size=2,
Args: Args:
data (iterable): Arbitrary iterable containing data points to gather data (iterable): Arbitrary iterable containing data points to gather
into clusters. Will be fully consumed. into clusters. Will be fully consumed.
key (callable): A function returning an item's key. radius (number): Jaccard similarity radius.
keys (callable): A function returning an item's keys. key (callable, optional): Function returning an item's key.
min_size (number, optional): minimum number of items in a cluster for min_size (number, optional): minimum number of items in a cluster for
it to be considered viable. Defaults to 2. it to be considered viable. Defaults to 2.
max_size (number, optional): maximum number of items in a cluster for max_size (number, optional): maximum number of items in a cluster for
it to be considered viable. Defaults to infinity. it to be considered viable. Defaults to infinity.
merge (bool, optional): whether to merge the buckets to form clusters.
Yield: Yield:
list: A viable cluster. list: A viable cluster.
@ -62,16 +61,12 @@ def jaccard_intersection_index(data, radius=0.8, key=None, min_size=2,
for j in bucket: for j in bucket:
intersections[i][j] += 1 intersections[i][j] += 1
intersections[j][i] += 1
bucket.append(i) bucket.append(i)
visited = set()
graph = defaultdict(list) graph = defaultdict(list)
for i, neighbors in intersections.items(): for i, neighbors in intersections.items():
if i in visited:
continue
for j, I in neighbors.items(): for j, I in neighbors.items():
U = sizes[i] + sizes[j] - I U = sizes[i] + sizes[j] - I
@ -80,20 +75,29 @@ def jaccard_intersection_index(data, radius=0.8, key=None, min_size=2,
graph[i].append(j) graph[i].append(j)
graph[j].append(i) graph[j].append(i)
visited.add(j)
visited = set() visited = set()
stack = []
for i, neighbors in graph.items(): for i, neighbors in graph.items():
if i in visited: if i in visited:
continue continue
if len(neighbors) + 1 < min_size: visited.add(i)
continue
if len(neighbors) + 1 > max_size:
continue
visited.update(neighbors) cluster = [data[i]]
stack.extend(neighbors)
while len(stack) != 0:
j = stack.pop()
if j in visited:
continue
cluster.append(data[j])
visited.add(j)
if j in graph:
stack.extend(graph[j])
cluster = [data[i]] + [data[j] for j in neighbors]
yield cluster yield cluster