diff --git a/spacy/tagger.pyx b/spacy/tagger.pyx index 6b895a2ce..b2bc344bb 100644 --- a/spacy/tagger.pyx +++ b/spacy/tagger.pyx @@ -15,6 +15,7 @@ from .tokens.doc cimport Doc from .attrs cimport TAG from .parts_of_speech cimport NO_TAG, ADJ, ADV, ADP, CONJ, DET, NOUN, NUM, PRON from .parts_of_speech cimport VERB, X, PUNCT, EOL, SPACE +from .gold cimport GoldParse from .attrs cimport * @@ -134,6 +135,7 @@ cdef class Tagger: for tag in self.tag_names: self.freqs[TAG][self.vocab.strings[tag]] = 1 self.freqs[TAG][0] = 1 + self.cfg = cfg @property def tag_names(self): @@ -180,11 +182,8 @@ cdef class Tagger: self(doc) yield doc - def update(self, Doc tokens, object gold): - if hasattr(gold, 'tags'): - gold_tag_strs = list(gold.tags) - else: - gold_tag_strs = gold + def update(self, Doc tokens, GoldParse gold): + gold_tag_strs = gold.tags assert len(tokens) == len(gold_tag_strs) for tag in gold_tag_strs: if tag != None and tag not in self.tag_names: