Fix train.py for 1.0

This commit is contained in:
root 2016-11-25 08:55:33 -06:00
parent 271a120d30
commit 080d29e092
1 changed files with 19 additions and 10 deletions

View File

@ -14,22 +14,31 @@ class Trainer(object):
self.gold_tuples = gold_tuples
def epochs(self, nr_epoch, augment_data=None, gold_preproc=False):
def _epoch():
for raw_text, paragraph_tuples in self.gold_tuples:
cached_golds = {}
def _epoch(indices):
for i in indices:
raw_text, paragraph_tuples = self.gold_tuples[i]
if gold_preproc:
raw_text = None
else:
paragraph_tuples = merge_sents(paragraph_tuples)
if augment_data is not None:
if augment_data is None:
docs = self.make_docs(raw_text, paragraph_tuples)
if i in cached_golds:
golds = cached_golds[i]
else:
golds = self.make_golds(docs, paragraph_tuples)
else:
raw_text, paragraph_tuples = augment_data(raw_text, paragraph_tuples)
docs = self.make_docs(raw_text, paragraph_tuples)
golds = self.make_golds(docs, paragraph_tuples)
for doc, gold in zip(docs, golds):
yield doc, gold
indices = list(range(len(self.gold_tuples)))
for itn in range(nr_epoch):
random.shuffle(self.gold_tuples)
yield _epoch()
random.shuffle(indices)
yield _epoch(indices)
def update(self, doc, gold):
for process in self.nlp.pipeline:
@ -48,7 +57,7 @@ class Trainer(object):
docs = self.make_docs(raw_text, paragraph_tuples)
golds = self.make_golds(docs, paragraph_tuples)
for doc, gold in zip(docs, golds):
for process in self.nlp.pipeline[1:]:
for process in self.nlp.pipeline:
process(doc)
scorer.score(doc, gold)
return scorer
@ -62,8 +71,8 @@ class Trainer(object):
def make_golds(self, docs, paragraph_tuples):
if len(docs) == 1:
return [GoldParse(docs[0], sent_tuples[0])
return [GoldParse.from_annot_tuples(docs[0], sent_tuples[0])
for sent_tuples in paragraph_tuples]
else:
return [GoldParse(doc, sent_tuples[0])
return [GoldParse.from_annot_tuples(doc, sent_tuples[0])
for doc, sent_tuples in zip(docs, paragraph_tuples)]