Disregard special tag _SP in check for new tag map (#5641)

* Skip special tag  _SP in check for new tag map

In `Tagger.begin_training()` check for new tags aside from `_SP` in the
new tag map initialized from the provided gold tuples when determining
whether to reinitialize the morphology with the new tag map.

* Simplify _SP check
This commit is contained in:
Adriane Boyd 2020-06-26 09:23:21 +02:00 committed by GitHub
parent fd4287c178
commit b7107ac89f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 16 additions and 2 deletions

View File

@ -528,10 +528,10 @@ class Tagger(Pipe):
new_tag_map[tag] = orig_tag_map[tag] new_tag_map[tag] = orig_tag_map[tag]
else: else:
new_tag_map[tag] = {POS: X} new_tag_map[tag] = {POS: X}
if "_SP" in orig_tag_map:
new_tag_map["_SP"] = orig_tag_map["_SP"]
cdef Vocab vocab = self.vocab cdef Vocab vocab = self.vocab
if new_tag_map: if new_tag_map:
if "_SP" in orig_tag_map:
new_tag_map["_SP"] = orig_tag_map["_SP"]
vocab.morphology = Morphology(vocab.strings, new_tag_map, vocab.morphology = Morphology(vocab.strings, new_tag_map,
vocab.morphology.lemmatizer, vocab.morphology.lemmatizer,
exc=vocab.morphology.exc) exc=vocab.morphology.exc)

View File

@ -3,6 +3,7 @@ from __future__ import unicode_literals
import pytest import pytest
from spacy.language import Language from spacy.language import Language
from spacy.symbols import POS, NOUN
def test_label_types(): def test_label_types():
@ -11,3 +12,16 @@ def test_label_types():
nlp.get_pipe("tagger").add_label("A") nlp.get_pipe("tagger").add_label("A")
with pytest.raises(ValueError): with pytest.raises(ValueError):
nlp.get_pipe("tagger").add_label(9) nlp.get_pipe("tagger").add_label(9)
def test_tagger_begin_training_tag_map():
"""Test that Tagger.begin_training() without gold tuples does not clobber
the tag map."""
nlp = Language()
tagger = nlp.create_pipe("tagger")
orig_tag_count = len(tagger.labels)
tagger.add_label("A", {"POS": "NOUN"})
nlp.add_pipe(tagger)
nlp.begin_training()
assert nlp.vocab.morphology.tag_map["A"] == {POS: NOUN}
assert orig_tag_count + 1 == len(nlp.get_pipe("tagger").labels)