mirror of https://github.com/explosion/spaCy.git
Switch to new add API in PhraseMatcher unpickle
This commit is contained in:
parent
ae1c179f3a
commit
e06ca7ea24
|
@ -332,7 +332,7 @@ def unpickle_matcher(vocab, docs, callbacks, attr):
|
|||
matcher = PhraseMatcher(vocab, attr=attr)
|
||||
for key, specs in docs.items():
|
||||
callback = callbacks.get(key, None)
|
||||
matcher.add(key, callback, *specs)
|
||||
matcher.add(key, specs, on_match=callback)
|
||||
return matcher
|
||||
|
||||
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
from __future__ import unicode_literals
|
||||
|
||||
import pytest
|
||||
import srsly
|
||||
from mock import Mock
|
||||
from spacy.matcher import PhraseMatcher
|
||||
from spacy.tokens import Doc
|
||||
|
@ -266,3 +267,26 @@ def test_phrase_matcher_basic_check(en_vocab):
|
|||
pattern = Doc(en_vocab, words=["hello", "world"])
|
||||
with pytest.raises(ValueError):
|
||||
matcher.add("TEST", pattern)
|
||||
|
||||
|
||||
def test_phrase_matcher_pickle(en_vocab):
|
||||
matcher = PhraseMatcher(en_vocab)
|
||||
mock = Mock()
|
||||
matcher.add("TEST", [Doc(en_vocab, words=["test"])])
|
||||
matcher.add("TEST2", [Doc(en_vocab, words=["test2"])], on_match=mock)
|
||||
doc = Doc(en_vocab, words=["these", "are", "tests", ":", "test", "test2"])
|
||||
assert len(matcher) == 2
|
||||
|
||||
b = srsly.pickle_dumps(matcher)
|
||||
matcher_unpickled = srsly.pickle_loads(b)
|
||||
|
||||
# call after pickling to avoid recursion error related to mock
|
||||
matches = matcher(doc)
|
||||
matches_unpickled = matcher_unpickled(doc)
|
||||
|
||||
assert len(matcher) == len(matcher_unpickled)
|
||||
assert matches == matches_unpickled
|
||||
|
||||
# clunky way to vaguely check that callback is unpickled
|
||||
(vocab, docs, callbacks, attr) = matcher_unpickled.__reduce__()[1]
|
||||
assert isinstance(callbacks.get("TEST2"), Mock)
|
||||
|
|
Loading…
Reference in New Issue