Support sentence limits in GoldCorpus

This commit is contained in:
Matthew Honnibal 2017-05-22 10:40:46 -05:00
parent e2136232f9
commit c9760b2104
1 changed files with 19 additions and 9 deletions

View File

@ -144,7 +144,7 @@ def _min_edit_path(cand_words, gold_words):
class GoldCorpus(object):
"""An annotated corpus, using the JSON file format. Manages
annotations for tagging, dependency parsing and NER."""
def __init__(self, train_path, dev_path):
def __init__(self, train_path, dev_path, limit=None):
"""Create a GoldCorpus.
train_path (unicode or Path): File or directory of training data.
@ -152,20 +152,31 @@ class GoldCorpus(object):
"""
self.train_path = util.ensure_path(train_path)
self.dev_path = util.ensure_path(dev_path)
self.limit = limit
self.train_locs = self.walk_corpus(self.train_path)
self.dev_locs = self.walk_corpus(self.dev_path)
@property
def train_tuples(self):
i = 0
for loc in self.train_locs:
gold_tuples = read_json_file(loc)
yield from gold_tuples
for item in gold_tuples:
yield item
i += 1
if self.limit and i >= self.limit:
break
@property
def dev_tuples(self):
i = 0
for loc in self.dev_locs:
gold_tuples = read_json_file(loc)
yield from gold_tuples
for item in gold_tuples:
yield item
i += 1
if self.limit and i >= self.limit:
break
def count_train(self):
n = 0
@ -175,8 +186,7 @@ class GoldCorpus(object):
def train_docs(self, nlp, shuffle=0, gold_preproc=True,
projectivize=False):
if shuffle:
random.shuffle(self.train_locs)
train_tuples = self.train_tuples
if projectivize:
train_tuples = nonproj.preprocess_training_data(
self.train_tuples)
@ -185,13 +195,13 @@ class GoldCorpus(object):
gold_docs = self.iter_gold_docs(nlp, train_tuples, gold_preproc)
yield from gold_docs
def dev_docs(self, nlp):
gold_docs = self.iter_gold_docs(nlp, self.dev_tuples)
def dev_docs(self, nlp, gold_preproc=True):
gold_docs = self.iter_gold_docs(nlp, self.dev_tuples, gold_preproc)
gold_docs = nlp.preprocess_gold(gold_docs)
yield from gold_docs
@classmethod
def iter_gold_docs(cls, nlp, tuples, gold_preproc=True):
def iter_gold_docs(cls, nlp, tuples, gold_preproc):
for raw_text, paragraph_tuples in tuples:
docs = cls._make_docs(nlp, raw_text, paragraph_tuples,
gold_preproc)
@ -275,7 +285,7 @@ def read_json_file(loc, docs_filter=None, limit=None):
ner.append(token.get('ner', '-'))
sents.append([
[ids, words, tags, heads, labels, ner],
sent.get('brackets', [])])
sent.get('brackets', [])])
if sents:
yield [paragraph.get('raw', None), sents]