mirror of https://github.com/explosion/spaCy.git
Support oracle segmentation in ud-train CLI command
This commit is contained in:
parent
c49e44349a
commit
fc4dd49b77
|
@ -161,9 +161,19 @@ def golds_to_gold_tuples(docs, golds):
|
|||
##############
|
||||
|
||||
def evaluate(nlp, text_loc, gold_loc, sys_loc, limit=None):
|
||||
with text_loc.open('r', encoding='utf8') as text_file:
|
||||
texts = split_text(text_file.read())
|
||||
docs = list(nlp.pipe(texts))
|
||||
if text_loc.parts[-1].endswith('.conllu'):
|
||||
docs = []
|
||||
with text_loc.open() as file_:
|
||||
for conllu_doc in read_conllu(file_):
|
||||
for conllu_sent in conllu_doc:
|
||||
words = [line[1] for line in conllu_sent]
|
||||
docs.append(Doc(nlp.vocab, words=words))
|
||||
for name, component in nlp.pipeline:
|
||||
docs = list(component.pipe(docs))
|
||||
else:
|
||||
with text_loc.open('r', encoding='utf8') as text_file:
|
||||
texts = split_text(text_file.read())
|
||||
docs = list(nlp.pipe(texts))
|
||||
with sys_loc.open('w', encoding='utf8') as out_file:
|
||||
write_conllu(docs, out_file)
|
||||
with gold_loc.open('r', encoding='utf8') as gold_file:
|
||||
|
@ -261,12 +271,12 @@ def load_nlp(corpus, config, vectors=None):
|
|||
|
||||
|
||||
def initialize_pipeline(nlp, docs, golds, config, device):
|
||||
nlp.add_pipe(nlp.create_pipe('tagger'))
|
||||
nlp.add_pipe(nlp.create_pipe('parser'))
|
||||
if config.multitask_tag:
|
||||
nlp.parser.add_multitask_objective('tag')
|
||||
if config.multitask_sent:
|
||||
nlp.parser.add_multitask_objective('sent_start')
|
||||
nlp.add_pipe(nlp.create_pipe('tagger'))
|
||||
for gold in golds:
|
||||
for tag in gold.tags:
|
||||
if tag is not None:
|
||||
|
@ -328,10 +338,12 @@ class TreebankPaths(object):
|
|||
config=("Path to json formatted config file", "positional"),
|
||||
limit=("Size limit", "option", "n", int),
|
||||
use_gpu=("Use GPU", "option", "g", int),
|
||||
use_oracle_segments=("Use oracle segments", "flag", "G", int),
|
||||
vectors_dir=("Path to directory with pre-trained vectors, named e.g. en/",
|
||||
"option", "v", Path),
|
||||
)
|
||||
def main(ud_dir, parses_dir, config, corpus, limit=0, use_gpu=-1, vectors_dir=None):
|
||||
def main(ud_dir, parses_dir, config, corpus, limit=0, use_gpu=-1, vectors_dir=None,
|
||||
use_oracle_segments=False):
|
||||
spacy.util.fix_random_seed()
|
||||
lang.zh.Chinese.Defaults.use_jieba = False
|
||||
lang.ja.Japanese.Defaults.use_janome = False
|
||||
|
@ -344,13 +356,17 @@ def main(ud_dir, parses_dir, config, corpus, limit=0, use_gpu=-1, vectors_dir=No
|
|||
nlp = load_nlp(paths.lang, config, vectors=vectors_dir)
|
||||
|
||||
docs, golds = read_data(nlp, paths.train.conllu.open(), paths.train.text.open(),
|
||||
max_doc_length=config.max_doc_length, limit=limit)
|
||||
max_doc_length=None, limit=limit)
|
||||
|
||||
optimizer = initialize_pipeline(nlp, docs, golds, config, use_gpu)
|
||||
|
||||
batch_sizes = compounding(config.batch_size//10, config.batch_size, 1.001)
|
||||
nlp.parser.cfg['beam_update_prob'] = 1.0
|
||||
for i in range(config.nr_epoch):
|
||||
docs = [nlp.make_doc(doc.text) for doc in docs]
|
||||
docs, golds = read_data(nlp, paths.train.conllu.open(), paths.train.text.open(),
|
||||
max_doc_length=config.max_doc_length, limit=limit,
|
||||
oracle_segments=use_oracle_segments,
|
||||
raw_text=not use_oracle_segments)
|
||||
Xs = list(zip(docs, golds))
|
||||
random.shuffle(Xs)
|
||||
batches = minibatch_by_words(Xs, size=batch_sizes)
|
||||
|
@ -365,7 +381,12 @@ def main(ud_dir, parses_dir, config, corpus, limit=0, use_gpu=-1, vectors_dir=No
|
|||
|
||||
out_path = parses_dir / corpus / 'epoch-{i}.conllu'.format(i=i)
|
||||
with nlp.use_params(optimizer.averages):
|
||||
parsed_docs, scores = evaluate(nlp, paths.dev.text, paths.dev.conllu, out_path)
|
||||
if use_oracle_segments:
|
||||
parsed_docs, scores = evaluate(nlp, paths.dev.conllu,
|
||||
paths.dev.conllu, out_path)
|
||||
else:
|
||||
parsed_docs, scores = evaluate(nlp, paths.dev.text,
|
||||
paths.dev.conllu, out_path)
|
||||
print_progress(i, losses, scores)
|
||||
_render_parses(i, parsed_docs[:50])
|
||||
|
||||
|
|
Loading…
Reference in New Issue