Support gold preprocessing and single gold files

This commit is contained in:
Matthew Honnibal 2017-05-21 17:50:49 -05:00
parent e14533757b
commit f13d6c7359
1 changed file with 20 additions and 10 deletions

View File

@@ -168,33 +168,41 @@ class GoldCorpus(object):
n += 1
return n
def train_docs(self, nlp, shuffle=0):
def train_docs(self, nlp, shuffle=0, gold_preproc=True):
if shuffle:
random.shuffle(self.train_locs)
gold_docs = self.iter_gold_docs(nlp, self.train_tuples)
gold_docs = self.iter_gold_docs(nlp, self.train_tuples, gold_preproc)
if shuffle:
gold_docs = util.itershuffle(gold_docs, bufsize=shuffle*1000)
gold_docs = nlp.preprocess_gold(gold_docs)
yield from gold_docs
def dev_docs(self, nlp):
yield from self.iter_gold_docs(nlp, self.dev_tuples)
gold_docs = self.iter_gold_docs(nlp, self.dev_tuples)
gold_docs = nlp.preprocess_gold(gold_docs)
yield from gold_docs
@classmethod
def iter_gold_docs(cls, nlp, tuples):
def iter_gold_docs(cls, nlp, tuples, gold_preproc=True):
tuples = nonproj.PseudoProjectivity.preprocess_training_data(tuples)
for raw_text, paragraph_tuples in tuples:
docs = cls._make_docs(nlp, raw_text, paragraph_tuples)
docs = cls._make_docs(nlp, raw_text, paragraph_tuples,
gold_preproc)
golds = cls._make_golds(docs, paragraph_tuples)
for doc, gold in zip(docs, golds):
yield doc, gold
@classmethod
def _make_docs(cls, nlp, raw_text, paragraph_tuples):
if raw_text is not None:
def _make_docs(cls, nlp, raw_text, paragraph_tuples, gold_preproc):
if gold_preproc:
return [Doc(nlp.vocab, words=sent_tuples[0][1])
for sent_tuples in paragraph_tuples]
elif raw_text is not None:
return [nlp.make_doc(raw_text)]
else:
return [
Doc(nlp.vocab, words=sent_tuples[0][1])
docs = [Doc(nlp.vocab, words=sent_tuples[0][1])
for sent_tuples in paragraph_tuples]
return merge_sents(docs)
@classmethod
def _make_golds(cls, docs, paragraph_tuples):
@@ -207,8 +215,10 @@ class GoldCorpus(object):
@staticmethod
def walk_corpus(path):
locs = []
if not path.is_dir():
return [path]
paths = [path]
locs = []
seen = set()
for path in paths:
if str(path) in seen: