diff --git a/spacy/errors.py b/spacy/errors.py index a6b199a50..02656e0e7 100644 --- a/spacy/errors.py +++ b/spacy/errors.py @@ -476,6 +476,8 @@ class Errors(object): E168 = ("Unknown field: {field}") E169 = ("Can't find module: {module}") E170 = ("Cannot apply transition {name}: invalid for the current state.") + E171 = ("Matcher.add received invalid on_match callback argument: expected " + "callable or None, but got: {arg_type}") @add_codes diff --git a/spacy/matcher/matcher.pyx b/spacy/matcher/matcher.pyx index fe6ccc781..950a7b977 100644 --- a/spacy/matcher/matcher.pyx +++ b/spacy/matcher/matcher.pyx @@ -103,6 +103,8 @@ cdef class Matcher: *patterns (list): List of token descriptions. """ errors = {} + if on_match is not None and not hasattr(on_match, "__call__"): + raise ValueError(Errors.E171.format(arg_type=type(on_match))) for i, pattern in enumerate(patterns): if len(pattern) == 0: raise ValueError(Errors.E012.format(key=key)) diff --git a/spacy/tests/matcher/test_matcher_api.py b/spacy/tests/matcher/test_matcher_api.py index df35a1be2..0d640e1a2 100644 --- a/spacy/tests/matcher/test_matcher_api.py +++ b/spacy/tests/matcher/test_matcher_api.py @@ -410,3 +410,11 @@ def test_matcher_schema_token_attributes(en_vocab, pattern, text): assert len(matcher) == 1 matches = matcher(doc) assert len(matches) == 1 + + +def test_matcher_valid_callback(en_vocab): + """Test that on_match can only be None or callable.""" + matcher = Matcher(en_vocab) + with pytest.raises(ValueError): + matcher.add("TEST", [], [{"TEXT": "test"}]) + matcher(Doc(en_vocab, words=["test"]))