From 935ac53ee3cc498b372379db1527f3d522b13b94 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Tue, 14 Jul 2015 03:20:09 +0200 Subject: [PATCH] * Extend count_by method --- spacy/tokens/doc.pyx | 28 +++++++++++++++++++--------- 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/spacy/tokens/doc.pyx b/spacy/tokens/doc.pyx index 8d6266dea..737fe3b8c 100644 --- a/spacy/tokens/doc.pyx +++ b/spacy/tokens/doc.pyx @@ -218,7 +218,7 @@ cdef class Doc: output[i, j] = get_token_attr(&self.data[i], feature) return output - def count_by(self, attr_id_t attr_id, exclude=None): + def count_by(self, attr_id_t attr_id, exclude=None, PreshCounter counts=None): """Produce a dict of {attribute (int): count (ints)} frequencies, keyed by the values of the given attribute ID. @@ -236,14 +236,24 @@ cdef class Doc: cdef int i cdef attr_t attr cdef size_t count - - cdef PreshCounter counts = PreshCounter(2 ** 8) - for i in range(self.length): - if exclude is not None and exclude(self[i]): - continue - attr = get_token_attr(&self.data[i], attr_id) - counts.inc(attr, 1) - return dict(counts) + + if counts is None: + counts = PreshCounter(self.length) + output_dict = True + else: + output_dict = False + # Take this check out of the loop, for a bit of extra speed + if exclude is None: + for i in range(self.length): + attr = get_token_attr(&self.data[i], attr_id) + counts.inc(attr, 1) + else: + for i in range(self.length): + if not exclude(self[i]): + attr = get_token_attr(&self.data[i], attr_id) + counts.inc(attr, 1) + if output_dict: + return dict(counts) def _realloc(self, new_size): self.max_length = new_size