diff --git a/spacy/tests/regression/test_issue4190.py b/spacy/tests/regression/test_issue4190.py new file mode 100644 index 000000000..464996705 --- /dev/null +++ b/spacy/tests/regression/test_issue4190.py @@ -0,0 +1,57 @@ +# coding: utf8 +from __future__ import unicode_literals + +from spacy.lang.en import English + +import spacy +from spacy.tokenizer import Tokenizer + +from spacy.tests.util import make_tempdir + + +def test_issue4190(): + test_string = "Test c." + + # Load default language + nlp_1 = English() + doc_1a = nlp_1(test_string) + result_1a = [token.text for token in doc_1a] + + # Modify tokenizer + customize_tokenizer(nlp_1) + doc_1b = nlp_1(test_string) + result_1b = [token.text for token in doc_1b] + + # Save and Reload + with make_tempdir() as model_dir: + nlp_1.to_disk(model_dir) + nlp_2 = spacy.load(model_dir) + + # This should be the modified tokenizer + doc_2 = nlp_2(test_string) + result_2 = [token.text for token in doc_2] + + assert result_1b == result_2 + + +def customize_tokenizer(nlp): + prefix_re = spacy.util.compile_prefix_regex(nlp.Defaults.prefixes) + suffix_re = spacy.util.compile_suffix_regex(nlp.Defaults.suffixes) + infix_re = spacy.util.compile_infix_regex(nlp.Defaults.infixes) + + # remove all exceptions where a single letter is followed by a period (e.g. 'h.') + exceptions = { + k: v + for k, v in dict(nlp.Defaults.tokenizer_exceptions).items() + if not (len(k) == 2 and k[1] == ".") + } + new_tokenizer = Tokenizer( + nlp.vocab, + exceptions, + prefix_search=prefix_re.search, + suffix_search=suffix_re.search, + infix_finditer=infix_re.finditer, + token_match=nlp.tokenizer.token_match, + ) + + nlp.tokenizer = new_tokenizer diff --git a/spacy/tokenizer.pyx b/spacy/tokenizer.pyx index f19f851c7..19029ec05 100644 --- a/spacy/tokenizer.pyx +++ b/spacy/tokenizer.pyx @@ -441,8 +441,13 @@ cdef class Tokenizer: self.infix_finditer = re.compile(data["infix_finditer"]).finditer if data.get("token_match"): self.token_match = re.compile(data["token_match"]).match - for string, substrings in data.get("rules", {}).items(): - self.add_special_case(string, substrings) + if data.get("rules"): + # make sure to hard reset the cache to remove data from the default exceptions + self._rules = {} + self._cache = PreshMap() + for string, substrings in data.get("rules", {}).items(): + self.add_special_case(string, substrings) + return self