mirror of https://github.com/explosion/spaCy.git
parent
483dddc9bc
commit
b589b945db
|
@ -33,6 +33,7 @@ cdef class PhraseMatcher:
|
|||
cdef attr_id_t attr
|
||||
cdef public object _callbacks
|
||||
cdef public object _patterns
|
||||
cdef public object _docs
|
||||
cdef public object _validate
|
||||
|
||||
def __init__(self, Vocab vocab, max_length=0, attr='ORTH', validate=False):
|
||||
|
@ -55,6 +56,7 @@ cdef class PhraseMatcher:
|
|||
]
|
||||
self.matcher.add('Candidate', None, *abstract_patterns)
|
||||
self._callbacks = {}
|
||||
self._docs = {}
|
||||
self._validate = validate
|
||||
|
||||
def __len__(self):
|
||||
|
@ -64,7 +66,7 @@ cdef class PhraseMatcher:
|
|||
|
||||
RETURNS (int): The number of rules.
|
||||
"""
|
||||
return len(self.phrase_ids)
|
||||
return len(self._docs)
|
||||
|
||||
def __contains__(self, key):
|
||||
"""Check whether the matcher contains rules for a match ID.
|
||||
|
@ -76,7 +78,8 @@ cdef class PhraseMatcher:
|
|||
return ent_id in self._callbacks
|
||||
|
||||
def __reduce__(self):
|
||||
return (self.__class__, (self.vocab,), None, None)
|
||||
data = (self.vocab, self._docs, self._callbacks)
|
||||
return (unpickle_matcher, data, None, None)
|
||||
|
||||
def add(self, key, on_match, *docs):
|
||||
"""Add a match-rule to the phrase-matcher. A match-rule consists of: an ID
|
||||
|
@ -89,6 +92,7 @@ cdef class PhraseMatcher:
|
|||
cdef Doc doc
|
||||
cdef hash_t ent_id = self.matcher._normalize_key(key)
|
||||
self._callbacks[ent_id] = on_match
|
||||
self._docs[ent_id] = docs
|
||||
cdef int length
|
||||
cdef int i
|
||||
cdef hash_t phrase_hash
|
||||
|
@ -213,3 +217,11 @@ def get_bilou(length):
|
|||
return [B3_ENT, I3_ENT, L3_ENT]
|
||||
else:
|
||||
return [B4_ENT, I4_ENT] + [I4_ENT] * (length-3) + [L4_ENT]
|
||||
|
||||
|
||||
def unpickle_matcher(vocab, docs, callbacks):
    """Rebuild a PhraseMatcher from the state stored when it was pickled.

    vocab: Passed as the only argument to the PhraseMatcher constructor.
    docs (dict): Maps each match key to the patterns added under that key.
    callbacks (dict): Maps match keys to their on-match callbacks.
    RETURNS (PhraseMatcher): A new matcher with every stored rule re-added.
    """
    # Only the vocab is forwarded here, so the new matcher is created with
    # the constructor's default max_length, attr and validate values.
    restored = PhraseMatcher(vocab)
    for rule_key in docs:
        patterns = docs[rule_key]
        # Keys without a stored callback are re-added with None.
        on_match = callbacks.get(rule_key)
        restored.add(rule_key, on_match, *patterns)
    return restored
|
||||
|
|
|
@ -0,0 +1,28 @@
|
|||
# coding: utf-8
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import pytest
|
||||
from spacy.matcher import PhraseMatcher
|
||||
from spacy.lang.en import English
|
||||
from spacy.compat import pickle
|
||||
|
||||
|
||||
def test_issue3248_1():
    """Check that len() of a PhraseMatcher counts the rules that were added,
    not the total number of patterns across those rules."""
    en = English()
    phrase_matcher = PhraseMatcher(en.vocab)
    # Two rules: the first holds three patterns, the second holds one.
    phrase_matcher.add("TEST1", None, nlp_a := en("a"), en("b"), en("c")) if False else \
        phrase_matcher.add("TEST1", None, en("a"), en("b"), en("c"))
    phrase_matcher.add("TEST2", None, en("d"))
    rule_count = len(phrase_matcher)
    assert rule_count == 2
|
||||
|
||||
|
||||
def test_issue3248_2():
    """Check that a PhraseMatcher survives a pickle round trip with the same
    number of rules."""
    en = English()
    original = PhraseMatcher(en.vocab)
    original.add("TEST1", None, en("a"), en("b"), en("c"))
    original.add("TEST2", None, en("d"))
    # Serialize and immediately restore the matcher.
    restored = pickle.loads(pickle.dumps(original))
    assert len(restored) == len(original)
|
Loading…
Reference in New Issue