Review note: relative to the patch as first posted, the spacy/gold.pyx
post-image now sorts the directory listing, documents read_json_file and
no longer carries over the unused "paragraphs" local. The blob id after
".." on that file's index line predates these amendments.

diff --git a/bin/parser/train.py b/bin/parser/train.py
index d63106333..1c410d737 100755
--- a/bin/parser/train.py
+++ b/bin/parser/train.py
@@ -138,8 +138,8 @@ def write_parses(Language, dev_loc, model_dir, out_loc):
 
 
 @plac.annotations(
-    train_loc=("Location of training json file"),
-    dev_loc=("Location of development json file"),
+    train_loc=("Location of training file or directory"),
+    dev_loc=("Location of development file or directory"),
     corruption_level=("Amount of noise to add to training data", "option", "c", float),
     model_dir=("Location of output model directory",),
     out_loc=("Out location", "option", "o", str),
diff --git a/spacy/gold.pyx b/spacy/gold.pyx
index 0bc2d1f72..d29ae1f35 100644
--- a/spacy/gold.pyx
+++ b/spacy/gold.pyx
@@ -4,6 +4,8 @@ import json
 import ijson
 import random
 import re
+import os
+from os import path
 
 from spacy.munge.read_ner import tags_to_entities
 from libc.string cimport memset
@@ -94,28 +96,43 @@ def _min_edit_path(cand_words, gold_words):
 
 
 def read_json_file(loc):
-    with open(loc) as file_:
-        for doc in ijson.items(file_, 'item'):
-            paragraphs = []
-            for paragraph in doc['paragraphs']:
-                words = []
-                ids = []
-                tags = []
-                heads = []
-                labels = []
-                ner = []
-                for token in paragraph['tokens']:
-                    words.append(token['orth'])
-                    ids.append(token['id'])
-                    tags.append(token['tag'])
-                    heads.append(token['head'] if token['head'] >= 0 else token['id'])
-                    labels.append(token['dep'])
-                    ner.append(token.get('ner', '-'))
+    """Yield gold-standard parses from a JSON file or a directory of them.
+
+    If loc is a directory, every entry in it is read in sorted name order,
+    recursing into sub-directories. Otherwise loc is opened and its
+    documents are iterated with ijson.items(file_, 'item').
+
+    Yields one (raw, (ids, words, tags, heads, labels, ner), brackets)
+    tuple per paragraph; raw is None and brackets is [] when absent.
+    """
+    if path.isdir(loc):
+        # os.listdir returns entries in arbitrary order; sort so that the
+        # sequence of yielded examples is reproducible across runs.
+        for filename in sorted(os.listdir(loc)):
+            yield from read_json_file(path.join(loc, filename))
+    else:
+        with open(loc) as file_:
+            for doc in ijson.items(file_, 'item'):
+                for paragraph in doc['paragraphs']:
+                    words = []
+                    ids = []
+                    tags = []
+                    heads = []
+                    labels = []
+                    ner = []
+                    for token in paragraph['tokens']:
+                        words.append(token['orth'])
+                        ids.append(token['id'])
+                        tags.append(token['tag'])
+                        # Negative head values are replaced with the token's own id.
+                        heads.append(token['head'] if token['head'] >= 0 else token['id'])
+                        labels.append(token['dep'])
+                        ner.append(token.get('ner', '-'))
 
-                yield (
-                    paragraph.get('raw', None),
-                    (ids, words, tags, heads, labels, ner),
-                    paragraph.get('brackets', []))
+                    yield (
+                        paragraph.get('raw', None),
+                        (ids, words, tags, heads, labels, ner),
+                        paragraph.get('brackets', []))
 
 
 def _iob_to_biluo(tags):