2018-02-21 12:53:59 +00:00
|
|
|
'''Train for CONLL 2017 UD treebank evaluation. Takes .conllu files, writes
|
|
|
|
.conllu format for development data, allowing the official scorer to be used.
|
|
|
|
'''
|
|
|
|
from __future__ import unicode_literals
|
|
|
|
import plac
|
|
|
|
import tqdm
|
|
|
|
import re
|
2018-02-22 15:00:34 +00:00
|
|
|
import sys
|
2018-02-21 12:53:59 +00:00
|
|
|
import spacy
|
|
|
|
import spacy.util
|
2018-02-22 15:00:34 +00:00
|
|
|
from spacy.tokens import Doc
|
2018-02-21 12:53:59 +00:00
|
|
|
from spacy.gold import GoldParse, minibatch
|
|
|
|
from spacy.syntax.nonproj import projectivize
|
|
|
|
from collections import Counter
|
|
|
|
from timeit import default_timer as timer
|
2018-02-24 09:31:53 +00:00
|
|
|
from spacy.matcher import Matcher
|
2018-02-21 12:53:59 +00:00
|
|
|
|
2018-02-23 22:49:17 +00:00
|
|
|
import random
|
|
|
|
import numpy.random
|
|
|
|
|
2018-02-21 12:53:59 +00:00
|
|
|
from spacy._align import align
|
|
|
|
|
2018-02-23 22:49:17 +00:00
|
|
|
# Seed both the stdlib and the numpy random number generators with the same
# fixed value (presumably so that repeated training runs are comparable).
random.seed(0)
numpy.random.seed(0)
|
|
|
|
|
2018-02-21 12:53:59 +00:00
|
|
|
|
|
|
|
def get_token_acc(docs, golds):
    '''Quick function to evaluate tokenization accuracy.

    docs: Sequence of tokenized documents. Only len() of each one is used.
    golds: Sequence of gold-standard objects, paired with docs by position.
        gold.words[i] is read for every token index i of the paired doc; a
        value of None is counted as a miss, anything else as a hit.
    RETURNS: A (miss, hit) tuple of token counts. Note that misses come first.
    '''
    miss = 0
    hit = 0
    for doc, gold in zip(docs, golds):
        # Index by token position (rather than zipping doc with gold.words) so
        # that a gold.words shorter than the doc still raises IndexError
        # instead of silently dropping tokens.
        for i in range(len(doc)):
            # Identity test: `== None` can be answered by a custom __eq__.
            if gold.words[i] is None:
                miss += 1
            else:
                hit += 1
    return miss, hit
|
|
|
|
|
|
|
|
|
|
|
|
def golds_to_gold_tuples(docs, golds):
    '''Convert paired docs and golds into the nested 'tuples' format that is
    handed to begin_training.

    Each result entry is (doc.text, [(columns, [])]), where columns is the
    6-tuple (ids, words, tags, heads, labels, iob) obtained by transposing
    gold.orig_annot. The whole document is emitted as a single segment.
    '''
    def to_entry(doc, gold):
        raw_text = doc.text
        # Unpack into exactly six names so that a differently-shaped
        # orig_annot fails loudly here.
        ids, words, tags, heads, labels, iob = zip(*gold.orig_annot)
        columns = (ids, words, tags, heads, labels, iob)
        return (raw_text, [(columns, [])])

    return [to_entry(doc, gold) for doc, gold in zip(docs, golds)]
|
|
|
|
|
|
|
|
def split_text(text):
    '''Split text into paragraphs at blank lines ('\\n\\n'). Each paragraph is
    stripped of surrounding whitespace and has its remaining newlines replaced
    by single spaces.'''
    paragraphs = []
    for chunk in text.split('\n\n'):
        flattened = chunk.strip().replace('\n', ' ')
        paragraphs.append(flattened)
    return paragraphs
|
|
|
|
|
|
|
|
|
|
|
|
def read_data(nlp, conllu_file, text_file, raw_text=True, oracle_segments=False,
              limit=None):
    '''Read the CONLLU format into (Doc, GoldParse) tuples. If raw_text=True,
    include Doc objects created using nlp.make_doc and then aligned against
    the gold-standard sequences. If oracle_segments=True, include Doc objects
    created from the gold-standard segments. At least one must be True.

    nlp: Pipeline object; nlp.vocab and nlp.make_doc are used.
    conllu_file: Iterable of CoNLL-U lines, passed to read_conllu.
    text_file: File-like object whose .read() returns the raw text, with
        paragraphs separated by blank lines (see split_text).
    limit: Maximum number of CoNLL-U documents to read. None or 0 means no
        limit. With oracle_segments=True every sentence adds its own entry, so
        the returned lists can be longer than limit.
    RETURNS: (docs, golds), two parallel lists.
    RAISES: ValueError if both raw_text and oracle_segments are false.

    Paragraphs and CoNLL-U documents are paired by position with zip, so any
    surplus on either side is silently ignored.
    '''
    if not raw_text and not oracle_segments:
        raise ValueError("At least one of raw_text or oracle_segments must be True")
    paragraphs = split_text(text_file.read())
    conllu = read_conllu(conllu_file)
    # cd is conllu doc; cs is conllu sent
    docs = []
    golds = []
    for doc_id, (text, cd) in enumerate(zip(paragraphs, conllu)):
        # Check before doing any work: testing `doc_id >= limit` after the
        # document had been processed read limit + 1 documents.
        if limit and doc_id >= limit:
            break
        doc_words = []
        doc_tags = []
        doc_heads = []
        doc_deps = []
        doc_ents = []
        for cs in cd:
            sent_words = []
            sent_tags = []
            sent_heads = []
            sent_deps = []
            for id_, word, _lemma, _pos, tag, _morph, head, dep, _1, _2 in cs:
                # Skip rows whose id is not a plain integer. In CoNLL-U these
                # are empty nodes ('1.1') and multiword-token ranges ('1-2').
                if '.' in id_ or '-' in id_:
                    continue
                id_ = int(id_) - 1
                # Heads become 0-based sentence offsets; a root ('0') is
                # encoded as pointing at itself.
                head = int(head) - 1 if head != '0' else id_
                sent_words.append(word)
                sent_tags.append(tag)
                sent_heads.append(head)
                sent_deps.append('ROOT' if dep == 'root' else dep)
            if oracle_segments:
                # One Doc/GoldParse pair per gold-standard sentence.
                sent_heads, sent_deps = projectivize(sent_heads, sent_deps)
                docs.append(Doc(nlp.vocab, words=sent_words))
                golds.append(GoldParse(docs[-1], words=sent_words, heads=sent_heads,
                                       tags=sent_tags, deps=sent_deps,
                                       entities=['-'] * len(sent_words)))
            # Shift sentence-relative heads to document-relative positions
            # before the sentence's words are added to the document.
            offset = len(doc_words)
            doc_heads.extend(offset + sent_head for sent_head in sent_heads)
            doc_words.extend(sent_words)
            doc_tags.extend(sent_tags)
            doc_deps.extend(sent_deps)
            doc_ents.extend(['-'] * len(sent_words))
        # Document-level annotations, used for the raw-text GoldParse below.
        doc_heads, doc_deps = projectivize(doc_heads, doc_deps)
        if raw_text:
            docs.append(nlp.make_doc(text))
            golds.append(GoldParse(docs[-1], words=doc_words, tags=doc_tags,
                                   heads=doc_heads, deps=doc_deps,
                                   entities=doc_ents))
    return docs, golds
|
|
|
|
|
|
|
|
|
|
|
|
def refresh_docs(docs):
    '''Build a fresh Doc for every input doc, copying only the token texts
    and their trailing whitespace. Any other state on the input docs is not
    carried over.

    docs: List of Doc objects. All new docs use the vocab of docs[0].
    RETURNS: A new list of Doc objects; an empty list for empty input.
    '''
    if not docs:
        # Nothing to copy, and no docs[0] to take the vocab from.
        return []
    vocab = docs[0].vocab
    return [Doc(vocab, words=[t.text for t in doc],
                spaces=[t.whitespace_ for t in doc])
            for doc in docs]
|
2018-02-21 12:53:59 +00:00
|
|
|
|
|
|
|
|
|
|
|
def read_conllu(file_):
    '''Parse CoNLL-U lines into nested lists: docs -> sentences -> rows ->
    fields.

    file_: Iterable of lines (e.g. an open text file).
    RETURNS: A list of documents. Each document is a list of sentences, each
        sentence a list of rows, and each row a list of field strings.

    A line starting with '# newdoc' opens a new document; other '#' lines are
    ignored; a blank line ends the current sentence. Sentences seen before any
    '# newdoc' line each become a document of their own. A document that was
    opened but received no sentences is dropped.
    '''
    docs = []
    doc = None
    sent = []
    for line in file_:
        if line.startswith('# newdoc'):
            if doc:
                docs.append(doc)
            doc = []
        elif line.startswith('#'):
            continue
        elif not line.strip():
            if sent:
                if doc is None:
                    docs.append([sent])
                else:
                    doc.append(sent)
            sent = []
        else:
            row = line.strip()
            # CoNLL-U fields are tab-separated and FORM/LEMMA may contain
            # spaces, so split on tabs. Fall back to whitespace splitting for
            # rows without any tab, which is how every row used to be split.
            sent.append(row.split('\t') if '\t' in row else row.split())
    # Flush a final sentence that was not followed by a blank line.
    if sent:
        if doc is None:
            docs.append([sent])
        else:
            doc.append(sent)
    if doc:
        docs.append(doc)
    return docs
|
|
|
|
|
|
|
|
|
2018-02-22 15:00:34 +00:00
|
|
|
def parse_dev_data(nlp, text_loc, conllu_loc, oracle_segments=False,
                   joint_sbd=True):
    '''Load the development data and score it with nlp.evaluate.

    nlp: Pipeline used to create and evaluate the docs.
    text_loc: Path of the raw text file.
    conllu_loc: Path of the CoNLL-U file.
    oracle_segments: Forwarded to read_data.
    joint_sbd: If true, the docs are evaluated as read. If false, sentence
        boundaries are first set on each doc from the output of a
        'sentencizer' pipe.
    RETURNS: (docs, scorer), where scorer is what nlp.evaluate returned.
    '''
    with open(text_loc) as text_file, open(conllu_loc) as conllu_file:
        docs, golds = read_data(nlp, conllu_file, text_file,
                                oracle_segments=oracle_segments)
    if not joint_sbd:
        sentencizer = nlp.create_pipe('sentencizer')
        for doc in docs:
            segmented = sentencizer(doc)
            # Mark the first token of every sentence as a sentence start and
            # every other token explicitly as not one.
            for sent in segmented.sents:
                sent[0].is_sent_start = True
                for word in sent[1:]:
                    word.is_sent_start = False
    scorer = nlp.evaluate(zip(docs, golds))
    return docs, scorer
|
|
|
|
|
|
|
|
|
2018-02-21 13:46:54 +00:00
|
|
|
def print_progress(itn, losses, scorer):
    '''Print one tab-separated progress line for iteration `itn`.

    itn: Integer iteration number (formatted with '{:d}').
    losses: Dict of losses keyed by pipe name ('parser', 'ner', 'tagger');
        missing entries are reported as 0.0.
    scorer: Object whose .scores dict supplies the evaluation figures; its
        entries override the 0.0 defaults and the losses on key clashes.

    Columns printed, in order: itn, dep_loss, ner_loss, uas, ents_p, ents_r,
    ents_f, tags_acc, token_acc.
    '''
    defaults = ('dep_loss', 'tag_loss', 'uas', 'tags_acc', 'token_acc',
                'ents_p', 'ents_r', 'ents_f', 'cpu_wps', 'gpu_wps')
    scores = dict.fromkeys(defaults, 0.0)
    loss_sources = (('dep_loss', 'parser'),
                    ('ner_loss', 'ner'),
                    ('tag_loss', 'tagger'))
    for column, pipe_name in loss_sources:
        scores[column] = losses.get(pipe_name, 0.0)
    scores.update(scorer.scores)
    cells = [
        '{:d}',
        '{dep_loss:.3f}',
        '{ner_loss:.3f}',
        '{uas:.3f}',
        '{ents_p:.3f}',
        '{ents_r:.3f}',
        '{ents_f:.3f}',
        '{tags_acc:.3f}',
        '{token_acc:.3f}',
    ]
    row_template = '\t'.join(cells)
    print(row_template.format(itn, **scores))
|
|
|
|
|
2018-02-24 09:31:53 +00:00
|
|
|
|
2018-02-21 12:53:59 +00:00
|
|
|
def print_conllu(docs, file_):
    '''Write parsed docs to file_ in .conllu format.

    docs: List of parsed Doc objects. They are modified in place: runs of
        tokens labelled 'subtok' are merged before printing.
    file_: Writable text file-like object.

    Nothing is written when docs is empty.
    '''
    if not docs:
        # No docs[0] to take the vocab from, and nothing to print.
        return
    merger = Matcher(docs[0].vocab)
    merger.add('SUBTOK', None, [{'DEP': 'subtok', 'op': '+'}])
    for i, doc in enumerate(docs):
        matches = merger(doc)
        # Collect character offsets first, since merging changes the token
        # indices. Matches that run up to the end of the doc are skipped.
        # NOTE(review): the span ends at doc[end+1] -- confirm against the
        # Matcher's end-index convention that this is the intended last token.
        spans = [(doc[start].idx, doc[end+1].idx+len(doc[end+1]))
                 for (_, start, end) in matches if end < (len(doc)-1)]
        for start_char, end_char in spans:
            doc.merge(start_char, end_char)
        file_.write("# newdoc id = {i}\n".format(i=i))
        for j, sent in enumerate(doc.sents):
            file_.write("# sent_id = {i}.{j}\n".format(i=i, j=j))
            file_.write("# text = {text}\n".format(text=sent.text))
            for k, t in enumerate(sent):
                if t.head.i == t.i:
                    # A token that is its own head is written as the root.
                    head = 0
                else:
                    # Convert the head's doc position into a 1-based
                    # position within the sentence.
                    head = k + (t.head.i - t.i) + 1
                fields = [str(k+1), t.text, t.lemma_, t.pos_, t.tag_, '_',
                          str(head), t.dep_.lower(), '_', '_']
                file_.write('\t'.join(fields) + '\n')
            file_.write('\n')
|
|
|
|
|
|
|
|
|
2018-02-24 09:31:53 +00:00
|
|
|
def main(lang, conllu_train_loc, text_train_loc, conllu_dev_loc, text_dev_loc,
         output_loc):
    '''Train a parser and tagger on the training data and evaluate on the
    development data after every iteration.

    lang: Language code passed to spacy.blank. For 'en', vectors are loaded
        from a hard-coded model path.
    conllu_train_loc, text_train_loc: Paths of the training .conllu and raw
        text files.
    conllu_dev_loc, text_dev_loc: Paths of the development .conllu and raw
        text files.
    output_loc: Path the development parses are written to in .conllu format;
        it is overwritten on every iteration.
    '''
    nlp = spacy.blank(lang)
    if lang == 'en':
        vec_nlp = spacy.util.load_model('spacy/data/en_core_web_lg/en_core_web_lg-2.0.0')
        nlp.vocab.vectors = vec_nlp.vocab.vectors
        for lex in vec_nlp.vocab:
            # Look every string up in the blank vocab (presumably so that an
            # entry exists for each word that has a vector -- confirm).
            nlp.vocab[lex.orth_]
        vec_nlp = None
    with open(conllu_train_loc) as conllu_file, open(text_train_loc) as text_file:
        train_docs, train_golds = read_data(nlp, conllu_file, text_file,
                                            oracle_segments=False, raw_text=True,
                                            limit=None)
    print("Create parser")
    nlp.add_pipe(nlp.create_pipe('parser'))
    nlp.parser.add_multitask_objective('tag')
    nlp.parser.add_multitask_objective('sent_start')
    nlp.parser.moves.add_action(2, 'subtok')
    nlp.add_pipe(nlp.create_pipe('tagger'))
    for gold in train_golds:
        for tag in gold.tags:
            if tag is None:
                continue
            nlp.tagger.add_label(tag)

    def get_gold_tuples():
        return golds_to_gold_tuples(train_docs, train_golds)

    optimizer = nlp.begin_training(get_gold_tuples)
    # Replace labels that didn't make the frequency cutoff
    known_labels = {action.split('-')[1]
                    for action in set(nlp.parser.labels) if '-' in action}
    for gold in train_golds:
        for idx, label in enumerate(gold.labels):
            if label is None or label in known_labels:
                continue
            gold.labels[idx] = label.split('||')[0]
    n_train_words = sum(len(doc) for doc in train_docs)
    print(n_train_words)
    print("Begin training")
    # Batch size starts at 1 and grows, so that we make updates quickly
    # at the beginning of training.
    batch_from = spacy.util.env_opt('batch_from', 1)
    batch_to = spacy.util.env_opt('batch_to', 2)
    batch_compound = spacy.util.env_opt('batch_compound', 1.001)
    batch_sizes = spacy.util.compounding(batch_from, batch_to, batch_compound)
    for epoch in range(30):
        # Train on fresh Doc objects each iteration.
        train_docs = refresh_docs(train_docs)
        batches = minibatch(list(zip(train_docs, train_golds)), size=batch_sizes)
        with tqdm.tqdm(total=n_train_words, leave=False) as progress:
            losses = {}
            for batch in batches:
                if batch:
                    batch_docs, batch_gold = zip(*batch)
                    nlp.update(batch_docs, batch_gold, sgd=optimizer,
                               drop=0.2, losses=losses)
                    progress.update(sum(len(doc) for doc in batch_docs))

        # Evaluate with the optimizer's averaged parameters.
        with nlp.use_params(optimizer.averages):
            # First pass: evaluate the docs as read, and save the parses.
            dev_docs, scorer = parse_dev_data(nlp, text_dev_loc, conllu_dev_loc,
                                              oracle_segments=False, joint_sbd=True)
            print_progress(epoch, losses, scorer)
            with open(output_loc, 'w') as file_:
                print_conllu(dev_docs, file_)
            # Second pass: sentence boundaries preset from the sentencizer.
            dev_docs, scorer = parse_dev_data(nlp, text_dev_loc, conllu_dev_loc,
                                              oracle_segments=False, joint_sbd=False)
            print_progress(epoch, losses, scorer)
|
2018-02-21 12:53:59 +00:00
|
|
|
|
2018-02-22 15:00:34 +00:00
|
|
|
|
2018-02-21 12:53:59 +00:00
|
|
|
# Script entry point: plac.call invokes main (plac presumably maps the
# command-line arguments onto main's parameters -- see the plac docs).
if __name__ == '__main__':
    plac.call(main)
|