mirror of https://github.com/explosion/spaCy.git
Use GoldParse in tagger.update
This commit is contained in:
parent
59038f7efa
commit
517f090cbf
|
@ -15,6 +15,7 @@ from .tokens.doc cimport Doc
|
|||
from .attrs cimport TAG
|
||||
from .parts_of_speech cimport NO_TAG, ADJ, ADV, ADP, CONJ, DET, NOUN, NUM, PRON
|
||||
from .parts_of_speech cimport VERB, X, PUNCT, EOL, SPACE
|
||||
from .gold cimport GoldParse
|
||||
|
||||
from .attrs cimport *
|
||||
|
||||
|
@ -134,6 +135,7 @@ cdef class Tagger:
|
|||
for tag in self.tag_names:
|
||||
self.freqs[TAG][self.vocab.strings[tag]] = 1
|
||||
self.freqs[TAG][0] = 1
|
||||
self.cfg = cfg
|
||||
|
||||
@property
|
||||
def tag_names(self):
|
||||
|
@ -180,11 +182,8 @@ cdef class Tagger:
|
|||
self(doc)
|
||||
yield doc
|
||||
|
||||
def update(self, Doc tokens, object gold):
|
||||
if hasattr(gold, 'tags'):
|
||||
gold_tag_strs = list(gold.tags)
|
||||
else:
|
||||
gold_tag_strs = gold
|
||||
def update(self, Doc tokens, GoldParse gold):
|
||||
gold_tag_strs = gold.tags
|
||||
assert len(tokens) == len(gold_tag_strs)
|
||||
for tag in gold_tag_strs:
|
||||
if tag != None and tag not in self.tag_names:
|
||||
|
|
Loading…
Reference in New Issue