Support oracle segmentation in ud-train CLI command

Matthew Honnibal 2018-05-08 13:47:45 +02:00
parent c49e44349a
commit fc4dd49b77
1 changed file with 29 additions and 8 deletions


@@ -161,9 +161,19 @@ def golds_to_gold_tuples(docs, golds):
 ##############
 
 def evaluate(nlp, text_loc, gold_loc, sys_loc, limit=None):
-    with text_loc.open('r', encoding='utf8') as text_file:
-        texts = split_text(text_file.read())
-        docs = list(nlp.pipe(texts))
+    if text_loc.parts[-1].endswith('.conllu'):
+        docs = []
+        with text_loc.open() as file_:
+            for conllu_doc in read_conllu(file_):
+                for conllu_sent in conllu_doc:
+                    words = [line[1] for line in conllu_sent]
+                    docs.append(Doc(nlp.vocab, words=words))
+        for name, component in nlp.pipeline:
+            docs = list(component.pipe(docs))
+    else:
+        with text_loc.open('r', encoding='utf8') as text_file:
+            texts = split_text(text_file.read())
+            docs = list(nlp.pipe(texts))
     with sys_loc.open('w', encoding='utf8') as out_file:
         write_conllu(docs, out_file)
     with gold_loc.open('r', encoding='utf8') as gold_file:
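
Note on the hunk above: when text_loc points at a .conllu file, the tokenizer is bypassed entirely. Each gold sentence becomes a Doc built from the FORM column (index 1 of each CoNLL-U row), and the remaining pipeline components are then applied one by one with component.pipe. A minimal sketch of the pre-tokenized Doc construction this relies on (the blank English pipeline is an assumption for illustration; any vocab works):

    import spacy
    from spacy.tokens import Doc

    nlp = spacy.blank('en')
    # Building a Doc from a word list skips tokenization, so token (and
    # later sentence) boundaries come from the gold data, not the model.
    doc = Doc(nlp.vocab, words=['This', 'is', 'gold', 'tokenization', '.'])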
@@ -261,12 +271,12 @@ def load_nlp(corpus, config, vectors=None):
 
 
 def initialize_pipeline(nlp, docs, golds, config, device):
-    nlp.add_pipe(nlp.create_pipe('tagger'))
     nlp.add_pipe(nlp.create_pipe('parser'))
     if config.multitask_tag:
         nlp.parser.add_multitask_objective('tag')
     if config.multitask_sent:
         nlp.parser.add_multitask_objective('sent_start')
+    nlp.add_pipe(nlp.create_pipe('tagger'))
     for gold in golds:
         for tag in gold.tags:
             if tag is not None:
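
Note on the hunk above: the tagger is no longer created before the parser; it is added after the parser's multitask objectives are registered, so the component order becomes parser-then-tagger. A quick sketch of the ordering behaviour, using the same v2 API as the diff (blank model assumed for illustration):

    import spacy

    nlp = spacy.blank('en')
    nlp.add_pipe(nlp.create_pipe('parser'))
    nlp.add_pipe(nlp.create_pipe('tagger'))
    print(nlp.pipe_names)  # ['parser', 'tagger'] -- add_pipe appends in call order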
@@ -328,10 +338,12 @@ class TreebankPaths(object):
     config=("Path to json formatted config file", "positional"),
     limit=("Size limit", "option", "n", int),
     use_gpu=("Use GPU", "option", "g", int),
+    use_oracle_segments=("Use oracle segments", "flag", "G", int),
     vectors_dir=("Path to directory with pre-trained vectors, named e.g. en/",
                  "option", "v", Path),
 )
-def main(ud_dir, parses_dir, config, corpus, limit=0, use_gpu=-1, vectors_dir=None):
+def main(ud_dir, parses_dir, config, corpus, limit=0, use_gpu=-1, vectors_dir=None,
+         use_oracle_segments=False):
     spacy.util.fix_random_seed()
     lang.zh.Chinese.Defaults.use_jieba = False
     lang.ja.Japanese.Defaults.use_janome = False
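
Note on the hunk above: the plac annotation ("flag", "G") exposes the new option as a -G switch defaulting to False, so existing invocations behave as before. A hypothetical invocation (the treebank name and paths are placeholders, and the subcommand name is assumed from the commit title):

    python -m spacy ud-train /ud-treebanks /parses config.json en_ewt -G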
@@ -344,13 +356,17 @@ def main(ud_dir, parses_dir, config, corpus, limit=0, use_gpu=-1, vectors_dir=None):
     nlp = load_nlp(paths.lang, config, vectors=vectors_dir)
 
     docs, golds = read_data(nlp, paths.train.conllu.open(), paths.train.text.open(),
-                            max_doc_length=config.max_doc_length, limit=limit)
+                            max_doc_length=None, limit=limit)
 
     optimizer = initialize_pipeline(nlp, docs, golds, config, use_gpu)
 
     batch_sizes = compounding(config.batch_size//10, config.batch_size, 1.001)
     nlp.parser.cfg['beam_update_prob'] = 1.0
     for i in range(config.nr_epoch):
-        docs = [nlp.make_doc(doc.text) for doc in docs]
+        docs, golds = read_data(nlp, paths.train.conllu.open(), paths.train.text.open(),
+                                max_doc_length=config.max_doc_length, limit=limit,
+                                oracle_segments=use_oracle_segments,
+                                raw_text=not use_oracle_segments)
         Xs = list(zip(docs, golds))
         random.shuffle(Xs)
         batches = minibatch_by_words(Xs, size=batch_sizes)
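
Note on the hunk above: the initial read_data call now passes max_doc_length=None and only feeds initialize_pipeline; the training data is re-read at the top of every epoch with oracle_segments/raw_text toggled by the new flag, replacing the old per-epoch nlp.make_doc() re-tokenization. The body of read_data is outside this diff, so the sketch below is an assumption about what the oracle_segments path amounts to, mirroring the .conllu branch of evaluate() above (docs_from_oracle_segments is a hypothetical name):

    def docs_from_oracle_segments(nlp, conllu_file):
        # read_conllu is the script's own CoNLL-U reader; column 1 is FORM.
        docs = []
        for conllu_doc in read_conllu(conllu_file):
            for conllu_sent in conllu_doc:
                words = [line[1] for line in conllu_sent]
                docs.append(Doc(nlp.vocab, words=words))
        return docs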
@@ -365,7 +381,12 @@ def main(ud_dir, parses_dir, config, corpus, limit=0, use_gpu=-1, vectors_dir=None):
         out_path = parses_dir / corpus / 'epoch-{i}.conllu'.format(i=i)
         with nlp.use_params(optimizer.averages):
-            parsed_docs, scores = evaluate(nlp, paths.dev.text, paths.dev.conllu, out_path)
+            if use_oracle_segments:
+                parsed_docs, scores = evaluate(nlp, paths.dev.conllu,
+                                               paths.dev.conllu, out_path)
+            else:
+                parsed_docs, scores = evaluate(nlp, paths.dev.text,
+                                               paths.dev.conllu, out_path)
         print_progress(i, losses, scores)
         _render_parses(i, parsed_docs[:50])
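
Note on the hunk above: with use_oracle_segments set, paths.dev.conllu is passed as both the text and the gold location, so evaluate() takes the .conllu branch added in the first hunk and the system output shares the gold sentence boundaries. The reported scores then measure tagging and parsing quality without penalizing segmentation errors; without the flag, evaluation is unchanged and runs end-to-end from the raw dev text.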