From f7d950de6df12e729c9beb25ee25ea3dac01afaf Mon Sep 17 00:00:00 2001
From: Sofie Van Landeghem
Date: Thu, 1 Aug 2019 17:13:01 +0200
Subject: [PATCH] ensure the lang of vocab and nlp stay consistent (#4057)

* ensure the language of vocab and nlp stay consistent across serialization

* equality with =
---
 spacy/errors.py                          |  2 +
 spacy/language.py                        | 47 ++++++++++++++++++------
 spacy/tests/regression/test_issue4054.py | 33 +++++++++++++++++
 3 files changed, 71 insertions(+), 11 deletions(-)
 create mode 100644 spacy/tests/regression/test_issue4054.py

diff --git a/spacy/errors.py b/spacy/errors.py
index 1699809a7..945d3364a 100644
--- a/spacy/errors.py
+++ b/spacy/errors.py
@@ -415,6 +415,8 @@ class Errors(object):
             "is assigned to a KB identifier.")
     E149 = ("Error deserializing model. Check that the config used to create the "
             "component matches the model being loaded.")
+    E150 = ("The language of the `nlp` object and the `vocab` should be the same, "
+            "but found '{nlp}' and '{vocab}' respectively.")
 
 @add_codes
 class TempErrors(object):
diff --git a/spacy/language.py b/spacy/language.py
index bfdd00b79..b839be1f6 100644
--- a/spacy/language.py
+++ b/spacy/language.py
@@ -14,7 +14,8 @@ import srsly
 from .tokenizer import Tokenizer
 from .vocab import Vocab
 from .lemmatizer import Lemmatizer
-from .pipeline import DependencyParser, Tensorizer, Tagger, EntityRecognizer, EntityLinker
+from .pipeline import DependencyParser, Tagger
+from .pipeline import Tensorizer, EntityRecognizer, EntityLinker
 from .pipeline import SimilarityHook, TextCategorizer, Sentencizer
 from .pipeline import merge_noun_chunks, merge_entities, merge_subtokens
 from .pipeline import EntityRuler
@@ -158,6 +159,9 @@ class Language(object):
             vocab = factory(self, **meta.get("vocab", {}))
             if vocab.vectors.name is None:
                 vocab.vectors.name = meta.get("vectors", {}).get("name")
+        else:
+            if (self.lang and vocab.lang) and (self.lang != vocab.lang):
+                raise ValueError(Errors.E150.format(nlp=self.lang, vocab=vocab.lang))
         self.vocab = vocab
         if make_doc is True:
             factory = self.Defaults.create_tokenizer
@@ -173,7 +177,10 @@ class Language(object):
 
     @property
     def meta(self):
-        self._meta.setdefault("lang", self.vocab.lang)
+        if self.vocab.lang:
+            self._meta.setdefault("lang", self.vocab.lang)
+        else:
+            self._meta.setdefault("lang", self.lang)
         self._meta.setdefault("name", "model")
         self._meta.setdefault("version", "0.0.0")
         self._meta.setdefault("spacy_version", ">={}".format(about.__version__))
@@ -618,7 +625,9 @@ class Language(object):
         if component_cfg is None:
             component_cfg = {}
         docs, golds = zip(*docs_golds)
-        docs = [self.make_doc(doc) if isinstance(doc, basestring_) else doc for doc in docs]
+        docs = [
+            self.make_doc(doc) if isinstance(doc, basestring_) else doc for doc in docs
+        ]
         golds = list(golds)
         for name, pipe in self.pipeline:
             kwargs = component_cfg.get(name, {})
@@ -769,8 +778,12 @@ class Language(object):
             exclude = disable
         path = util.ensure_path(path)
         serializers = OrderedDict()
-        serializers["tokenizer"] = lambda p: self.tokenizer.to_disk(p, exclude=["vocab"])
-        serializers["meta.json"] = lambda p: p.open("w").write(srsly.json_dumps(self.meta))
+        serializers["tokenizer"] = lambda p: self.tokenizer.to_disk(
+            p, exclude=["vocab"]
+        )
+        serializers["meta.json"] = lambda p: p.open("w").write(
+            srsly.json_dumps(self.meta)
+        )
         for name, proc in self.pipeline:
             if not hasattr(proc, "name"):
                 continue
@@ -799,14 +812,20 @@ class Language(object):
         path = util.ensure_path(path)
         deserializers = OrderedDict()
         deserializers["meta.json"] = lambda p: self.meta.update(srsly.read_json(p))
-        deserializers["vocab"] = lambda p: self.vocab.from_disk(p) and _fix_pretrained_vectors_name(self)
-        deserializers["tokenizer"] = lambda p: self.tokenizer.from_disk(p, exclude=["vocab"])
+        deserializers["vocab"] = lambda p: self.vocab.from_disk(
+            p
+        ) and _fix_pretrained_vectors_name(self)
+        deserializers["tokenizer"] = lambda p: self.tokenizer.from_disk(
+            p, exclude=["vocab"]
+        )
         for name, proc in self.pipeline:
             if name in exclude:
                 continue
             if not hasattr(proc, "from_disk"):
                 continue
-            deserializers[name] = lambda p, proc=proc: proc.from_disk(p, exclude=["vocab"])
+            deserializers[name] = lambda p, proc=proc: proc.from_disk(
+                p, exclude=["vocab"]
+            )
         if not (path / "vocab").exists() and "vocab" not in exclude:
             # Convert to list here in case exclude is (default) tuple
             exclude = list(exclude) + ["vocab"]
@@ -852,14 +871,20 @@ class Language(object):
             exclude = disable
         deserializers = OrderedDict()
         deserializers["meta.json"] = lambda b: self.meta.update(srsly.json_loads(b))
-        deserializers["vocab"] = lambda b: self.vocab.from_bytes(b) and _fix_pretrained_vectors_name(self)
-        deserializers["tokenizer"] = lambda b: self.tokenizer.from_bytes(b, exclude=["vocab"])
+        deserializers["vocab"] = lambda b: self.vocab.from_bytes(
+            b
+        ) and _fix_pretrained_vectors_name(self)
+        deserializers["tokenizer"] = lambda b: self.tokenizer.from_bytes(
+            b, exclude=["vocab"]
+        )
         for name, proc in self.pipeline:
             if name in exclude:
                 continue
             if not hasattr(proc, "from_bytes"):
                 continue
-            deserializers[name] = lambda b, proc=proc: proc.from_bytes(b, exclude=["vocab"])
+            deserializers[name] = lambda b, proc=proc: proc.from_bytes(
+                b, exclude=["vocab"]
+            )
         exclude = util.get_serialization_exclude(deserializers, exclude, kwargs)
         util.from_bytes(bytes_data, deserializers, exclude)
         return self
diff --git a/spacy/tests/regression/test_issue4054.py b/spacy/tests/regression/test_issue4054.py
new file mode 100644
index 000000000..2c9d73751
--- /dev/null
+++ b/spacy/tests/regression/test_issue4054.py
@@ -0,0 +1,33 @@
+# coding: utf8
+from __future__ import unicode_literals
+
+from spacy.vocab import Vocab
+
+import spacy
+from spacy.lang.en import English
+from spacy.tests.util import make_tempdir
+from spacy.util import ensure_path
+
+
+def test_issue4054(en_vocab):
+    """Test that a new blank model can be made with a vocab from file,
+    and that serialization does not drop the language at any point."""
+    nlp1 = English()
+    vocab1 = nlp1.vocab
+
+    with make_tempdir() as d:
+        vocab_dir = ensure_path(d / "vocab")
+        if not vocab_dir.exists():
+            vocab_dir.mkdir()
+        vocab1.to_disk(vocab_dir)
+
+        vocab2 = Vocab().from_disk(vocab_dir)
+        print("lang", vocab2.lang)
+        nlp2 = spacy.blank("en", vocab=vocab2)
+
+        nlp_dir = ensure_path(d / "nlp")
+        if not nlp_dir.exists():
+            nlp_dir.mkdir()
+        nlp2.to_disk(nlp_dir)
+        nlp3 = spacy.load(nlp_dir)
+        assert nlp3.lang == "en"
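
A minimal sketch of the behaviour the new E150 check and the regression test cover (assuming spaCy v2.1 with this patch applied; the snippet is an illustration, not part of the commit): a vocab whose language differs from the `nlp` object is rejected at construction time, while a matching vocab keeps the language in the meta and through serialization.

    import spacy
    from spacy.lang.de import German
    from spacy.lang.en import English

    # Mismatching languages: the patched Language.__init__ raises E150.
    de_vocab = German().vocab
    try:
        English(vocab=de_vocab)
    except ValueError as err:
        print(err)  # E150 message, reporting 'en' and 'de'

    # Matching languages: the vocab is accepted and the lang survives in meta,
    # so to_disk()/spacy.load() round-trips keep nlp.lang == "en".
    en_vocab = English().vocab
    nlp = spacy.blank("en", vocab=en_vocab)
    assert nlp.meta["lang"] == "en"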