From b589b945dbe9c8660d58a99d17a9e849b987faf0 Mon Sep 17 00:00:00 2001 From: Ines Montani Date: Tue, 12 Feb 2019 18:27:54 +0100 Subject: [PATCH] Fix PhraseMatcher pickling and length (resolves #3248) (#3252) --- spacy/matcher/phrasematcher.pyx | 16 ++++++++++++-- spacy/tests/regression/test_issue3248.py | 28 ++++++++++++++++++++++++ 2 files changed, 42 insertions(+), 2 deletions(-) create mode 100644 spacy/tests/regression/test_issue3248.py diff --git a/spacy/matcher/phrasematcher.pyx b/spacy/matcher/phrasematcher.pyx index 81c81a008..04c8ad7dd 100644 --- a/spacy/matcher/phrasematcher.pyx +++ b/spacy/matcher/phrasematcher.pyx @@ -33,6 +33,7 @@ cdef class PhraseMatcher: cdef attr_id_t attr cdef public object _callbacks cdef public object _patterns + cdef public object _docs cdef public object _validate def __init__(self, Vocab vocab, max_length=0, attr='ORTH', validate=False): @@ -55,6 +56,7 @@ cdef class PhraseMatcher: ] self.matcher.add('Candidate', None, *abstract_patterns) self._callbacks = {} + self._docs = {} self._validate = validate def __len__(self): @@ -64,7 +66,7 @@ cdef class PhraseMatcher: RETURNS (int): The number of rules. """ - return len(self.phrase_ids) + return len(self._docs) def __contains__(self, key): """Check whether the matcher contains rules for a match ID. @@ -76,7 +78,8 @@ cdef class PhraseMatcher: return ent_id in self._callbacks def __reduce__(self): - return (self.__class__, (self.vocab,), None, None) + data = (self.vocab, self._docs, self._callbacks) + return (unpickle_matcher, data, None, None) def add(self, key, on_match, *docs): """Add a match-rule to the phrase-matcher. A match-rule consists of: an ID @@ -89,6 +92,7 @@ cdef class PhraseMatcher: cdef Doc doc cdef hash_t ent_id = self.matcher._normalize_key(key) self._callbacks[ent_id] = on_match + self._docs[ent_id] = docs cdef int length cdef int i cdef hash_t phrase_hash @@ -213,3 +217,11 @@ def get_bilou(length): return [B3_ENT, I3_ENT, L3_ENT] else: return [B4_ENT, I4_ENT] + [I4_ENT] * (length-3) + [L4_ENT] + + +def unpickle_matcher(vocab, docs, callbacks): + matcher = PhraseMatcher(vocab) + for key, specs in docs.items(): + callback = callbacks.get(key, None) + matcher.add(key, callback, *specs) + return matcher diff --git a/spacy/tests/regression/test_issue3248.py b/spacy/tests/regression/test_issue3248.py new file mode 100644 index 000000000..8df45bdc0 --- /dev/null +++ b/spacy/tests/regression/test_issue3248.py @@ -0,0 +1,28 @@ +# coding: utf-8 +from __future__ import unicode_literals + +import pytest +from spacy.matcher import PhraseMatcher +from spacy.lang.en import English +from spacy.compat import pickle + + +def test_issue3248_1(): + """Test that the PhraseMatcher correctly reports its number of rules, not + total number of patterns.""" + nlp = English() + matcher = PhraseMatcher(nlp.vocab) + matcher.add("TEST1", None, nlp("a"), nlp("b"), nlp("c")) + matcher.add("TEST2", None, nlp("d")) + assert len(matcher) == 2 + + +def test_issue3248_2(): + """Test that the PhraseMatcher can be pickled correctly.""" + nlp = English() + matcher = PhraseMatcher(nlp.vocab) + matcher.add("TEST1", None, nlp("a"), nlp("b"), nlp("c")) + matcher.add("TEST2", None, nlp("d")) + data = pickle.dumps(matcher) + new_matcher = pickle.loads(data) + assert len(new_matcher) == len(matcher)