Refactor conllu script, fix interface, generalize

This commit is contained in:
Matthew Honnibal 2018-02-25 14:54:47 +01:00
parent 551c93fe01
commit 9e960d24fc
1 changed files with 57 additions and 56 deletions

View File

@ -13,7 +13,7 @@ import json
import spacy import spacy
import spacy.util import spacy.util
from spacy.tokens import Token, Doc from spacy.tokens import Token, Doc
from spacy.gold import GoldParse, minibatch from spacy.gold import GoldParse
from spacy.syntax.nonproj import projectivize from spacy.syntax.nonproj import projectivize
from collections import defaultdict, Counter from collections import defaultdict, Counter
from timeit import default_timer as timer from timeit import default_timer as timer
@ -24,7 +24,7 @@ import random
import numpy.random import numpy.random
import cytoolz import cytoolz
from spacy._align import align import conll17_ud_eval
random.seed(0) random.seed(0)
numpy.random.seed(0) numpy.random.seed(0)
@ -43,7 +43,8 @@ def minibatch_by_words(items, size=5000):
try: try:
doc, gold = next(items) doc, gold = next(items)
except StopIteration: except StopIteration:
yield batch if batch:
yield batch
return return
batch_size -= len(doc) batch_size -= len(doc)
batch.append((doc, gold)) batch.append((doc, gold))
@ -56,9 +57,9 @@ def minibatch_by_words(items, size=5000):
# Data reading # # Data reading #
################ ################
space_re = re.compile('\s+')
def split_text(text): def split_text(text):
return [par.strip().replace('\n', ' ') return [space_re.sub(' ', par.strip()) for par in text.split('\n\n')]
for par in text.split('\n\n')]
def read_data(nlp, conllu_file, text_file, raw_text=True, oracle_segments=False, def read_data(nlp, conllu_file, text_file, raw_text=True, oracle_segments=False,
@ -132,7 +133,10 @@ def read_conllu(file_):
doc.append(sent) doc.append(sent)
sent = [] sent = []
else: else:
sent.append(line.strip().split()) sent.append(list(line.strip().split('\t')))
if len(sent[-1]) != 10:
print(repr(line))
raise ValueError
if sent: if sent:
doc.append(sent) doc.append(sent)
if doc: if doc:
@ -176,50 +180,21 @@ def golds_to_gold_tuples(docs, golds):
# Evaluation # # Evaluation #
############## ##############
def parse_dev_data(nlp, text_loc, conllu_loc, oracle_segments=False, def evaluate(nlp, text_loc, gold_loc, sys_loc, limit=None):
joint_sbd=True, limit=None): with text_loc.open('r', encoding='utf8') as text_file:
with open(text_loc) as text_file: texts = split_text(text_file.read())
with open(conllu_loc) as conllu_file: docs = list(nlp.pipe(texts))
docs, golds = read_data(nlp, conllu_file, text_file, with sys_loc.open('w', encoding='utf8') as out_file:
oracle_segments=oracle_segments, limit=limit) write_conllu(docs, out_file)
if joint_sbd: with gold_loc.open('r', encoding='utf8') as gold_file:
pass gold_ud = conll17_ud_eval.load_conllu(gold_file)
else: with sys_loc.open('r', encoding='utf8') as sys_file:
sbd = nlp.create_pipe('sentencizer') sys_ud = conll17_ud_eval.load_conllu(sys_file)
for doc in docs: scores = conll17_ud_eval.evaluate(gold_ud, sys_ud)
doc = sbd(doc) return scores
for sent in doc.sents:
sent[0].is_sent_start = True
for word in sent[1:]:
word.is_sent_start = False
scorer = nlp.evaluate(zip(docs, golds))
return docs, scorer
def print_progress(itn, losses, scorer): def write_conllu(docs, file_):
scores = {}
for col in ['dep_loss', 'tag_loss', 'uas', 'tags_acc', 'token_acc',
'ents_p', 'ents_r', 'ents_f', 'cpu_wps', 'gpu_wps']:
scores[col] = 0.0
scores['dep_loss'] = losses.get('parser', 0.0)
scores['ner_loss'] = losses.get('ner', 0.0)
scores['tag_loss'] = losses.get('tagger', 0.0)
scores.update(scorer.scores)
tpl = '\t'.join((
'{:d}',
'{dep_loss:.3f}',
'{ner_loss:.3f}',
'{uas:.3f}',
'{ents_p:.3f}',
'{ents_r:.3f}',
'{ents_f:.3f}',
'{tags_acc:.3f}',
'{token_acc:.3f}',
))
print(tpl.format(itn, **scores))
def print_conllu(docs, file_):
merger = Matcher(docs[0].vocab) merger = Matcher(docs[0].vocab)
merger.add('SUBTOK', None, [{'DEP': 'subtok', 'op': '+'}]) merger.add('SUBTOK', None, [{'DEP': 'subtok', 'op': '+'}])
for i, doc in enumerate(docs): for i, doc in enumerate(docs):
@ -236,6 +211,31 @@ def print_conllu(docs, file_):
file_.write(token._.get_conllu_lines(k) + '\n') file_.write(token._.get_conllu_lines(k) + '\n')
file_.write('\n') file_.write('\n')
def print_progress(itn, losses, ud_scores):
fields = {
'dep_loss': losses.get('parser', 0.0),
'tag_loss': losses.get('tagger', 0.0),
'words': ud_scores['Words'].f1 * 100,
'sents': ud_scores['Sentences'].f1 * 100,
'tags': ud_scores['XPOS'].f1 * 100,
'uas': ud_scores['UAS'].f1 * 100,
'las': ud_scores['LAS'].f1 * 100,
}
header = ['Epoch', 'Loss', 'LAS', 'UAS', 'TAG', 'SENT', 'WORD']
if itn == 0:
print('\t'.join(header))
tpl = '\t'.join((
'{:d}',
'{dep_loss:.1f}',
'{las:.1f}',
'{uas:.1f}',
'{tags:.1f}',
'{sents:.1f}',
'{words:.1f}',
))
print(tpl.format(itn, **fields))
#def get_sent_conllu(sent, sent_id): #def get_sent_conllu(sent, sent_id):
# lines = ["# sent_id = {sent_id}".format(sent_id=sent_id)] # lines = ["# sent_id = {sent_id}".format(sent_id=sent_id)]
@ -275,7 +275,6 @@ def load_nlp(corpus, config):
return nlp return nlp
def initialize_pipeline(nlp, docs, golds, config): def initialize_pipeline(nlp, docs, golds, config):
print("Create parser")
nlp.add_pipe(nlp.create_pipe('parser')) nlp.add_pipe(nlp.create_pipe('parser'))
if config.multitask_tag: if config.multitask_tag:
nlp.parser.add_multitask_objective('tag') nlp.parser.add_multitask_objective('tag')
@ -347,14 +346,16 @@ class TreebankPaths(object):
@plac.annotations( @plac.annotations(
ud_dir=("Path to Universal Dependencies corpus", "positional", None, Path), ud_dir=("Path to Universal Dependencies corpus", "positional", None, Path),
config=("Path to json formatted config file", "positional", None, Config.load),
corpus=("UD corpus to train and evaluate on, e.g. en, es_ancora, etc", corpus=("UD corpus to train and evaluate on, e.g. en, es_ancora, etc",
"positional", None, str), "positional", None, str),
parses_loc=("Path to write the development parses", "positional", None, Path), parses_dir=("Directory to write the development parses", "positional", None, Path),
config=("Path to json formatted config file", "positional", None, Config.load),
limit=("Size limit", "option", "n", int) limit=("Size limit", "option", "n", int)
) )
def main(ud_dir, corpus, config, parses_loc='/tmp/dev.conllu', limit=10): def main(ud_dir, parses_dir, config, corpus, limit=0):
paths = TreebankPaths(ud_dir, corpus) paths = TreebankPaths(ud_dir, corpus)
if not (parses_dir / corpus).exists():
(parses_dir / corpus).mkdir()
print("Train and evaluate", corpus, "using lang", paths.lang) print("Train and evaluate", corpus, "using lang", paths.lang)
nlp = load_nlp(paths.lang, config) nlp = load_nlp(paths.lang, config)
@ -362,6 +363,7 @@ def main(ud_dir, corpus, config, parses_loc='/tmp/dev.conllu', limit=10):
max_doc_length=config.max_doc_length, limit=limit) max_doc_length=config.max_doc_length, limit=limit)
optimizer = initialize_pipeline(nlp, docs, golds, config) optimizer = initialize_pipeline(nlp, docs, golds, config)
for i in range(config.nr_epoch): for i in range(config.nr_epoch):
docs = [nlp.make_doc(doc.text) for doc in docs] docs = [nlp.make_doc(doc.text) for doc in docs]
batches = minibatch_by_words(list(zip(docs, golds)), size=config.batch_size) batches = minibatch_by_words(list(zip(docs, golds)), size=config.batch_size)
@ -374,11 +376,10 @@ def main(ud_dir, corpus, config, parses_loc='/tmp/dev.conllu', limit=10):
nlp.update(batch_docs, batch_gold, sgd=optimizer, nlp.update(batch_docs, batch_gold, sgd=optimizer,
drop=config.dropout, losses=losses) drop=config.dropout, losses=losses)
out_path = parses_dir / corpus / 'epoch-{i}.conllu'.format(i=i)
with nlp.use_params(optimizer.averages): with nlp.use_params(optimizer.averages):
dev_docs, scorer = parse_dev_data(nlp, paths.dev.text, paths.dev.conllu) scores = evaluate(nlp, paths.dev.text, paths.dev.conllu, out_path)
print_progress(i, losses, scorer) print_progress(i, losses, scores)
with open(parses_loc, 'w') as file_:
print_conllu(dev_docs, file_)
if __name__ == '__main__': if __name__ == '__main__':