diff --git a/spacy/gold.pyx b/spacy/gold.pyx index bc34290f4..651cefe2f 100644 --- a/spacy/gold.pyx +++ b/spacy/gold.pyx @@ -144,7 +144,7 @@ def _min_edit_path(cand_words, gold_words): class GoldCorpus(object): """An annotated corpus, using the JSON file format. Manages annotations for tagging, dependency parsing and NER.""" - def __init__(self, train_path, dev_path): + def __init__(self, train_path, dev_path, limit=None): """Create a GoldCorpus. train_path (unicode or Path): File or directory of training data. @@ -152,20 +152,31 @@ class GoldCorpus(object): """ self.train_path = util.ensure_path(train_path) self.dev_path = util.ensure_path(dev_path) + self.limit = limit self.train_locs = self.walk_corpus(self.train_path) self.dev_locs = self.walk_corpus(self.dev_path) @property def train_tuples(self): + i = 0 for loc in self.train_locs: gold_tuples = read_json_file(loc) - yield from gold_tuples + for item in gold_tuples: + yield item + i += 1 + if self.limit and i >= self.limit: + break @property def dev_tuples(self): + i = 0 for loc in self.dev_locs: gold_tuples = read_json_file(loc) - yield from gold_tuples + for item in gold_tuples: + yield item + i += 1 + if self.limit and i >= self.limit: + break def count_train(self): n = 0 @@ -175,8 +186,7 @@ class GoldCorpus(object): def train_docs(self, nlp, shuffle=0, gold_preproc=True, projectivize=False): - if shuffle: - random.shuffle(self.train_locs) + train_tuples = self.train_tuples if projectivize: train_tuples = nonproj.preprocess_training_data( self.train_tuples) @@ -185,13 +195,13 @@ class GoldCorpus(object): gold_docs = self.iter_gold_docs(nlp, train_tuples, gold_preproc) yield from gold_docs - def dev_docs(self, nlp): - gold_docs = self.iter_gold_docs(nlp, self.dev_tuples) + def dev_docs(self, nlp, gold_preproc=True): + gold_docs = self.iter_gold_docs(nlp, self.dev_tuples, gold_preproc) gold_docs = nlp.preprocess_gold(gold_docs) yield from gold_docs @classmethod - def iter_gold_docs(cls, nlp, tuples, gold_preproc=True): + def iter_gold_docs(cls, nlp, tuples, gold_preproc): for raw_text, paragraph_tuples in tuples: docs = cls._make_docs(nlp, raw_text, paragraph_tuples, gold_preproc) @@ -275,7 +285,7 @@ def read_json_file(loc, docs_filter=None, limit=None): ner.append(token.get('ner', '-')) sents.append([ [ids, words, tags, heads, labels, ner], - sent.get('brackets', [])]) + sent.get('brackets', [])]) if sents: yield [paragraph.get('raw', None), sents]