Fix train.py for 1.0

This commit is contained in:
root 2016-11-25 08:55:33 -06:00
parent 271a120d30
commit 080d29e092
1 changed files with 19 additions and 10 deletions

View File

@ -14,22 +14,31 @@ class Trainer(object):
self.gold_tuples = gold_tuples self.gold_tuples = gold_tuples
def epochs(self, nr_epoch, augment_data=None, gold_preproc=False): def epochs(self, nr_epoch, augment_data=None, gold_preproc=False):
def _epoch(): cached_golds = {}
for raw_text, paragraph_tuples in self.gold_tuples: def _epoch(indices):
for i in indices:
raw_text, paragraph_tuples = self.gold_tuples[i]
if gold_preproc: if gold_preproc:
raw_text = None raw_text = None
else: else:
paragraph_tuples = merge_sents(paragraph_tuples) paragraph_tuples = merge_sents(paragraph_tuples)
if augment_data is not None: if augment_data is None:
docs = self.make_docs(raw_text, paragraph_tuples)
if i in cached_golds:
golds = cached_golds[i]
else:
golds = self.make_golds(docs, paragraph_tuples)
else:
raw_text, paragraph_tuples = augment_data(raw_text, paragraph_tuples) raw_text, paragraph_tuples = augment_data(raw_text, paragraph_tuples)
docs = self.make_docs(raw_text, paragraph_tuples) docs = self.make_docs(raw_text, paragraph_tuples)
golds = self.make_golds(docs, paragraph_tuples) golds = self.make_golds(docs, paragraph_tuples)
for doc, gold in zip(docs, golds): for doc, gold in zip(docs, golds):
yield doc, gold yield doc, gold
indices = list(range(len(self.gold_tuples)))
for itn in range(nr_epoch): for itn in range(nr_epoch):
random.shuffle(self.gold_tuples) random.shuffle(indices)
yield _epoch() yield _epoch(indices)
def update(self, doc, gold): def update(self, doc, gold):
for process in self.nlp.pipeline: for process in self.nlp.pipeline:
@ -48,7 +57,7 @@ class Trainer(object):
docs = self.make_docs(raw_text, paragraph_tuples) docs = self.make_docs(raw_text, paragraph_tuples)
golds = self.make_golds(docs, paragraph_tuples) golds = self.make_golds(docs, paragraph_tuples)
for doc, gold in zip(docs, golds): for doc, gold in zip(docs, golds):
for process in self.nlp.pipeline[1:]: for process in self.nlp.pipeline:
process(doc) process(doc)
scorer.score(doc, gold) scorer.score(doc, gold)
return scorer return scorer
@ -62,8 +71,8 @@ class Trainer(object):
def make_golds(self, docs, paragraph_tuples): def make_golds(self, docs, paragraph_tuples):
if len(docs) == 1: if len(docs) == 1:
return [GoldParse(docs[0], sent_tuples[0]) return [GoldParse.from_annot_tuples(docs[0], sent_tuples[0])
for sent_tuples in paragraph_tuples] for sent_tuples in paragraph_tuples]
else: else:
return [GoldParse(doc, sent_tuples[0]) return [GoldParse.from_annot_tuples(doc, sent_tuples[0])
for doc, sent_tuples in zip(docs, paragraph_tuples)] for doc, sent_tuples in zip(docs, paragraph_tuples)]