From e06ca7ea24d151b77719b02f7430d724aa27a406 Mon Sep 17 00:00:00 2001 From: Adriane Boyd Date: Mon, 25 May 2020 10:13:56 +0200 Subject: [PATCH] Switch to new add API in PhraseMatcher unpickle --- spacy/matcher/phrasematcher.pyx | 2 +- spacy/tests/matcher/test_phrase_matcher.py | 24 ++++++++++++++++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/spacy/matcher/phrasematcher.pyx b/spacy/matcher/phrasematcher.pyx index b66ec35b8..00c3357f5 100644 --- a/spacy/matcher/phrasematcher.pyx +++ b/spacy/matcher/phrasematcher.pyx @@ -332,7 +332,7 @@ def unpickle_matcher(vocab, docs, callbacks, attr): matcher = PhraseMatcher(vocab, attr=attr) for key, specs in docs.items(): callback = callbacks.get(key, None) - matcher.add(key, callback, *specs) + matcher.add(key, specs, on_match=callback) return matcher diff --git a/spacy/tests/matcher/test_phrase_matcher.py b/spacy/tests/matcher/test_phrase_matcher.py index 7a6585e06..60aa584ef 100644 --- a/spacy/tests/matcher/test_phrase_matcher.py +++ b/spacy/tests/matcher/test_phrase_matcher.py @@ -2,6 +2,7 @@ from __future__ import unicode_literals import pytest +import srsly from mock import Mock from spacy.matcher import PhraseMatcher from spacy.tokens import Doc @@ -266,3 +267,26 @@ def test_phrase_matcher_basic_check(en_vocab): pattern = Doc(en_vocab, words=["hello", "world"]) with pytest.raises(ValueError): matcher.add("TEST", pattern) + + +def test_phrase_matcher_pickle(en_vocab): + matcher = PhraseMatcher(en_vocab) + mock = Mock() + matcher.add("TEST", [Doc(en_vocab, words=["test"])]) + matcher.add("TEST2", [Doc(en_vocab, words=["test2"])], on_match=mock) + doc = Doc(en_vocab, words=["these", "are", "tests", ":", "test", "test2"]) + assert len(matcher) == 2 + + b = srsly.pickle_dumps(matcher) + matcher_unpickled = srsly.pickle_loads(b) + + # call after pickling to avoid recursion error related to mock + matches = matcher(doc) + matches_unpickled = matcher_unpickled(doc) + + assert len(matcher) == len(matcher_unpickled) + assert matches == matches_unpickled + + # clunky way to vaguely check that callback is unpickled + (vocab, docs, callbacks, attr) = matcher_unpickled.__reduce__()[1] + assert isinstance(callbacks.get("TEST2"), Mock)