mirror of https://github.com/explosion/spaCy.git
Refactor conllu script, fix interface, generalize
This commit is contained in:
parent
551c93fe01
commit
9e960d24fc
|
@ -13,7 +13,7 @@ import json
|
|||
import spacy
|
||||
import spacy.util
|
||||
from spacy.tokens import Token, Doc
|
||||
from spacy.gold import GoldParse, minibatch
|
||||
from spacy.gold import GoldParse
|
||||
from spacy.syntax.nonproj import projectivize
|
||||
from collections import defaultdict, Counter
|
||||
from timeit import default_timer as timer
|
||||
|
@ -24,7 +24,7 @@ import random
|
|||
import numpy.random
|
||||
import cytoolz
|
||||
|
||||
from spacy._align import align
|
||||
import conll17_ud_eval
|
||||
|
||||
random.seed(0)
|
||||
numpy.random.seed(0)
|
||||
|
@ -43,7 +43,8 @@ def minibatch_by_words(items, size=5000):
|
|||
try:
|
||||
doc, gold = next(items)
|
||||
except StopIteration:
|
||||
yield batch
|
||||
if batch:
|
||||
yield batch
|
||||
return
|
||||
batch_size -= len(doc)
|
||||
batch.append((doc, gold))
|
||||
|
@ -56,9 +57,9 @@ def minibatch_by_words(items, size=5000):
|
|||
# Data reading #
|
||||
################
|
||||
|
||||
space_re = re.compile('\s+')
|
||||
def split_text(text):
|
||||
return [par.strip().replace('\n', ' ')
|
||||
for par in text.split('\n\n')]
|
||||
return [space_re.sub(' ', par.strip()) for par in text.split('\n\n')]
|
||||
|
||||
|
||||
def read_data(nlp, conllu_file, text_file, raw_text=True, oracle_segments=False,
|
||||
|
@ -132,7 +133,10 @@ def read_conllu(file_):
|
|||
doc.append(sent)
|
||||
sent = []
|
||||
else:
|
||||
sent.append(line.strip().split())
|
||||
sent.append(list(line.strip().split('\t')))
|
||||
if len(sent[-1]) != 10:
|
||||
print(repr(line))
|
||||
raise ValueError
|
||||
if sent:
|
||||
doc.append(sent)
|
||||
if doc:
|
||||
|
@ -176,50 +180,21 @@ def golds_to_gold_tuples(docs, golds):
|
|||
# Evaluation #
|
||||
##############
|
||||
|
||||
def parse_dev_data(nlp, text_loc, conllu_loc, oracle_segments=False,
|
||||
joint_sbd=True, limit=None):
|
||||
with open(text_loc) as text_file:
|
||||
with open(conllu_loc) as conllu_file:
|
||||
docs, golds = read_data(nlp, conllu_file, text_file,
|
||||
oracle_segments=oracle_segments, limit=limit)
|
||||
if joint_sbd:
|
||||
pass
|
||||
else:
|
||||
sbd = nlp.create_pipe('sentencizer')
|
||||
for doc in docs:
|
||||
doc = sbd(doc)
|
||||
for sent in doc.sents:
|
||||
sent[0].is_sent_start = True
|
||||
for word in sent[1:]:
|
||||
word.is_sent_start = False
|
||||
scorer = nlp.evaluate(zip(docs, golds))
|
||||
return docs, scorer
|
||||
def evaluate(nlp, text_loc, gold_loc, sys_loc, limit=None):
|
||||
with text_loc.open('r', encoding='utf8') as text_file:
|
||||
texts = split_text(text_file.read())
|
||||
docs = list(nlp.pipe(texts))
|
||||
with sys_loc.open('w', encoding='utf8') as out_file:
|
||||
write_conllu(docs, out_file)
|
||||
with gold_loc.open('r', encoding='utf8') as gold_file:
|
||||
gold_ud = conll17_ud_eval.load_conllu(gold_file)
|
||||
with sys_loc.open('r', encoding='utf8') as sys_file:
|
||||
sys_ud = conll17_ud_eval.load_conllu(sys_file)
|
||||
scores = conll17_ud_eval.evaluate(gold_ud, sys_ud)
|
||||
return scores
|
||||
|
||||
|
||||
def print_progress(itn, losses, scorer):
|
||||
scores = {}
|
||||
for col in ['dep_loss', 'tag_loss', 'uas', 'tags_acc', 'token_acc',
|
||||
'ents_p', 'ents_r', 'ents_f', 'cpu_wps', 'gpu_wps']:
|
||||
scores[col] = 0.0
|
||||
scores['dep_loss'] = losses.get('parser', 0.0)
|
||||
scores['ner_loss'] = losses.get('ner', 0.0)
|
||||
scores['tag_loss'] = losses.get('tagger', 0.0)
|
||||
scores.update(scorer.scores)
|
||||
tpl = '\t'.join((
|
||||
'{:d}',
|
||||
'{dep_loss:.3f}',
|
||||
'{ner_loss:.3f}',
|
||||
'{uas:.3f}',
|
||||
'{ents_p:.3f}',
|
||||
'{ents_r:.3f}',
|
||||
'{ents_f:.3f}',
|
||||
'{tags_acc:.3f}',
|
||||
'{token_acc:.3f}',
|
||||
))
|
||||
print(tpl.format(itn, **scores))
|
||||
|
||||
|
||||
def print_conllu(docs, file_):
|
||||
def write_conllu(docs, file_):
|
||||
merger = Matcher(docs[0].vocab)
|
||||
merger.add('SUBTOK', None, [{'DEP': 'subtok', 'op': '+'}])
|
||||
for i, doc in enumerate(docs):
|
||||
|
@ -236,6 +211,31 @@ def print_conllu(docs, file_):
|
|||
file_.write(token._.get_conllu_lines(k) + '\n')
|
||||
file_.write('\n')
|
||||
|
||||
|
||||
def print_progress(itn, losses, ud_scores):
|
||||
fields = {
|
||||
'dep_loss': losses.get('parser', 0.0),
|
||||
'tag_loss': losses.get('tagger', 0.0),
|
||||
'words': ud_scores['Words'].f1 * 100,
|
||||
'sents': ud_scores['Sentences'].f1 * 100,
|
||||
'tags': ud_scores['XPOS'].f1 * 100,
|
||||
'uas': ud_scores['UAS'].f1 * 100,
|
||||
'las': ud_scores['LAS'].f1 * 100,
|
||||
}
|
||||
header = ['Epoch', 'Loss', 'LAS', 'UAS', 'TAG', 'SENT', 'WORD']
|
||||
if itn == 0:
|
||||
print('\t'.join(header))
|
||||
tpl = '\t'.join((
|
||||
'{:d}',
|
||||
'{dep_loss:.1f}',
|
||||
'{las:.1f}',
|
||||
'{uas:.1f}',
|
||||
'{tags:.1f}',
|
||||
'{sents:.1f}',
|
||||
'{words:.1f}',
|
||||
))
|
||||
print(tpl.format(itn, **fields))
|
||||
|
||||
#def get_sent_conllu(sent, sent_id):
|
||||
# lines = ["# sent_id = {sent_id}".format(sent_id=sent_id)]
|
||||
|
||||
|
@ -275,7 +275,6 @@ def load_nlp(corpus, config):
|
|||
return nlp
|
||||
|
||||
def initialize_pipeline(nlp, docs, golds, config):
|
||||
print("Create parser")
|
||||
nlp.add_pipe(nlp.create_pipe('parser'))
|
||||
if config.multitask_tag:
|
||||
nlp.parser.add_multitask_objective('tag')
|
||||
|
@ -347,14 +346,16 @@ class TreebankPaths(object):
|
|||
|
||||
@plac.annotations(
|
||||
ud_dir=("Path to Universal Dependencies corpus", "positional", None, Path),
|
||||
config=("Path to json formatted config file", "positional", None, Config.load),
|
||||
corpus=("UD corpus to train and evaluate on, e.g. en, es_ancora, etc",
|
||||
"positional", None, str),
|
||||
parses_loc=("Path to write the development parses", "positional", None, Path),
|
||||
parses_dir=("Directory to write the development parses", "positional", None, Path),
|
||||
config=("Path to json formatted config file", "positional", None, Config.load),
|
||||
limit=("Size limit", "option", "n", int)
|
||||
)
|
||||
def main(ud_dir, corpus, config, parses_loc='/tmp/dev.conllu', limit=10):
|
||||
def main(ud_dir, parses_dir, config, corpus, limit=0):
|
||||
paths = TreebankPaths(ud_dir, corpus)
|
||||
if not (parses_dir / corpus).exists():
|
||||
(parses_dir / corpus).mkdir()
|
||||
print("Train and evaluate", corpus, "using lang", paths.lang)
|
||||
nlp = load_nlp(paths.lang, config)
|
||||
|
||||
|
@ -362,6 +363,7 @@ def main(ud_dir, corpus, config, parses_loc='/tmp/dev.conllu', limit=10):
|
|||
max_doc_length=config.max_doc_length, limit=limit)
|
||||
|
||||
optimizer = initialize_pipeline(nlp, docs, golds, config)
|
||||
|
||||
for i in range(config.nr_epoch):
|
||||
docs = [nlp.make_doc(doc.text) for doc in docs]
|
||||
batches = minibatch_by_words(list(zip(docs, golds)), size=config.batch_size)
|
||||
|
@ -374,11 +376,10 @@ def main(ud_dir, corpus, config, parses_loc='/tmp/dev.conllu', limit=10):
|
|||
nlp.update(batch_docs, batch_gold, sgd=optimizer,
|
||||
drop=config.dropout, losses=losses)
|
||||
|
||||
out_path = parses_dir / corpus / 'epoch-{i}.conllu'.format(i=i)
|
||||
with nlp.use_params(optimizer.averages):
|
||||
dev_docs, scorer = parse_dev_data(nlp, paths.dev.text, paths.dev.conllu)
|
||||
print_progress(i, losses, scorer)
|
||||
with open(parses_loc, 'w') as file_:
|
||||
print_conllu(dev_docs, file_)
|
||||
scores = evaluate(nlp, paths.dev.text, paths.dev.conllu, out_path)
|
||||
print_progress(i, losses, scores)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
Loading…
Reference in New Issue