From b7107ac89feee7f1aa1381d3c2978d09919288c2 Mon Sep 17 00:00:00 2001 From: Adriane Boyd Date: Fri, 26 Jun 2020 09:23:21 +0200 Subject: [PATCH] Disregard special tag _SP in check for new tag map (#5641) * Skip special tag _SP in check for new tag map In `Tagger.begin_training()` check for new tags aside from `_SP` in the new tag map initialized from the provided gold tuples when determining whether to reinitialize the morphology with the new tag map. * Simplify _SP check --- spacy/pipeline/pipes.pyx | 4 ++-- spacy/tests/pipeline/test_tagger.py | 14 ++++++++++++++ 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/spacy/pipeline/pipes.pyx b/spacy/pipeline/pipes.pyx index 3f40cb545..8f07bf8f7 100644 --- a/spacy/pipeline/pipes.pyx +++ b/spacy/pipeline/pipes.pyx @@ -528,10 +528,10 @@ class Tagger(Pipe): new_tag_map[tag] = orig_tag_map[tag] else: new_tag_map[tag] = {POS: X} - if "_SP" in orig_tag_map: - new_tag_map["_SP"] = orig_tag_map["_SP"] cdef Vocab vocab = self.vocab if new_tag_map: + if "_SP" in orig_tag_map: + new_tag_map["_SP"] = orig_tag_map["_SP"] vocab.morphology = Morphology(vocab.strings, new_tag_map, vocab.morphology.lemmatizer, exc=vocab.morphology.exc) diff --git a/spacy/tests/pipeline/test_tagger.py b/spacy/tests/pipeline/test_tagger.py index a5bda9090..1681ffeaa 100644 --- a/spacy/tests/pipeline/test_tagger.py +++ b/spacy/tests/pipeline/test_tagger.py @@ -3,6 +3,7 @@ from __future__ import unicode_literals import pytest from spacy.language import Language +from spacy.symbols import POS, NOUN def test_label_types(): @@ -11,3 +12,16 @@ def test_label_types(): nlp.get_pipe("tagger").add_label("A") with pytest.raises(ValueError): nlp.get_pipe("tagger").add_label(9) + + +def test_tagger_begin_training_tag_map(): + """Test that Tagger.begin_training() without gold tuples does not clobber + the tag map.""" + nlp = Language() + tagger = nlp.create_pipe("tagger") + orig_tag_count = len(tagger.labels) + tagger.add_label("A", {"POS": "NOUN"}) + nlp.add_pipe(tagger) + nlp.begin_training() + assert nlp.vocab.morphology.tag_map["A"] == {POS: NOUN} + assert orig_tag_count + 1 == len(nlp.get_pipe("tagger").labels)