Refactor CoNLL training script

Matthew Honnibal 2018-02-22 16:00:34 +01:00
parent 6a27a4f77c
commit 001e2ec6d6
1 changed file with 114 additions and 61 deletions


@@ -5,8 +5,10 @@ from __future__ import unicode_literals
 import plac
 import tqdm
 import re
+import sys
 import spacy
 import spacy.util
+from spacy.tokens import Doc
 from spacy.gold import GoldParse, minibatch
 from spacy.syntax.nonproj import projectivize
 from collections import Counter
@@ -78,16 +80,81 @@ def golds_to_gold_tuples(docs, golds):
         tuples.append((text, sents))
     return tuples


 def split_text(text):
-    paragraphs = text.split('\n\n')
-    paragraphs = [par.strip().replace('\n', ' ') for par in paragraphs]
-    return paragraphs
+    return [par.strip().replace('\n', ' ')
+            for par in text.split('\n\n')]
+
+
+def read_data(nlp, conllu_file, text_file, raw_text=True, oracle_segments=False,
+              limit=None):
+    '''Read the CONLLU format into (Doc, GoldParse) tuples. If raw_text=True,
+    include Doc objects created using nlp.make_doc and then aligned against
+    the gold-standard sequences. If oracle_segments=True, include Doc objects
+    created from the gold-standard segments. At least one must be True.'''
+    if not raw_text and not oracle_segments:
+        raise ValueError("At least one of raw_text or oracle_segments must be True")
+    paragraphs = split_text(text_file.read())
+    conllu = read_conllu(conllu_file)
+    # sd is spacy doc; cd is conllu doc
+    # cs is conllu sent, ct is conllu token
+    docs = []
+    golds = []
+    for text, cd in zip(paragraphs, conllu):
+        doc_words = []
+        doc_tags = []
+        doc_heads = []
+        doc_deps = []
+        doc_ents = []
+        for cs in cd:
+            sent_words = []
+            sent_tags = []
+            sent_heads = []
+            sent_deps = []
+            for id_, word, lemma, pos, tag, morph, head, dep, _1, _2 in cs:
+                if '.' in id_:
+                    continue
+                if '-' in id_:
+                    continue
+                id_ = int(id_)-1
+                head = int(head)-1 if head != '0' else id_
+                sent_words.append(word)
+                sent_tags.append(tag)
+                sent_heads.append(head)
+                sent_deps.append('ROOT' if dep == 'root' else dep)
+            if oracle_segments:
+                sent_heads, sent_deps = projectivize(sent_heads, sent_deps)
+                docs.append(Doc(nlp.vocab, words=sent_words))
+                golds.append(GoldParse(docs[-1], words=sent_words, heads=sent_heads,
+                                       tags=sent_tags, deps=sent_deps,
+                                       entities=['-']*len(sent_words)))
+            for head in sent_heads:
+                doc_heads.append(len(doc_words)+head)
+            doc_words.extend(sent_words)
+            doc_tags.extend(sent_tags)
+            doc_deps.extend(sent_deps)
+            doc_ents.extend(['-']*len(sent_words))
+        # Create a GoldParse object for the sentence
+        doc_heads, doc_deps = projectivize(doc_heads, doc_deps)
+        if raw_text:
+            docs.append(nlp.make_doc(text))
+            golds.append(GoldParse(docs[-1], words=doc_words, tags=doc_tags,
+                                   heads=doc_heads, deps=doc_deps,
+                                   entities=doc_ents))
+        if limit and len(docs) >= limit:
+            break
+    return docs, golds
+
+
+def refresh_docs(docs):
+    vocab = docs[0].vocab
+    return [Doc(vocab, words=[t.text for t in doc],
+                spaces=[t.whitespace_ for t in doc])
+            for doc in docs]


 def read_conllu(file_):
     docs = []
-    doc = []
+    doc = None
     sent = []
     for line in file_:
         if line.startswith('# newdoc'):
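
Note: the new read_data() consumes one CoNLL-U token row at a time. As a rough, self-contained sketch of that column handling (the sample rows are invented for illustration, not taken from the treebank), this is how the ten tab-separated fields are unpacked, how multi-word token ranges ('-' in the ID) and empty nodes ('.' in the ID) are skipped, and how 1-based heads become 0-based offsets with the root pointing at itself:

rows = [
    "1\tThey\tthey\tPRON\tPRP\t_\t2\tnsubj\t_\t_",
    "2\tbuy\tbuy\tVERB\tVBP\t_\t0\troot\t_\t_",
    "3\tbooks\tbook\tNOUN\tNNS\t_\t2\tobj\t_\t_",
]
words, tags, heads, deps = [], [], [], []
for row in rows:
    id_, word, lemma, pos, tag, morph, head, dep, _1, _2 = row.split('\t')
    if '.' in id_ or '-' in id_:
        continue  # skip empty nodes and multi-word token ranges
    id_ = int(id_) - 1
    heads.append(int(head) - 1 if head != '0' else id_)  # root points at itself
    words.append(word)
    tags.append(tag)
    deps.append('ROOT' if dep == 'root' else dep)
print(words)  # ['They', 'buy', 'books']
print(heads)  # [1, 1, 1] -- 'buy' is the root and heads the other two tokens
print(deps)   # ['nsubj', 'ROOT', 'obj']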
@@ -98,57 +165,37 @@ def read_conllu(file_):
             continue
         elif not line.strip():
             if sent:
-                doc.append(sent)
+                if doc is None:
+                    docs.append([sent])
+                else:
+                    doc.append(sent)
             sent = []
         else:
             sent.append(line.strip().split())
     if sent:
-        doc.append(sent)
+        if doc is None:
+            docs.append([sent])
+        else:
+            doc.append(sent)
     if doc:
         docs.append(doc)
     return docs

-def get_docs(nlp, text):
-    paragraphs = split_text(text)
-    docs = [nlp.make_doc(par) for par in paragraphs]
-    return docs
-
-
-def get_golds(docs, conllu):
-    # sd is spacy doc; cd is conllu doc
-    # cs is conllu sent, ct is conllu token
-    golds = []
-    for sd, cd in zip(docs, conllu):
-        words = []
-        tags = []
-        heads = []
-        deps = []
-        for cs in cd:
-            for id_, word, lemma, pos, tag, morph, head, dep, _1, _2 in cs:
-                if '.' in id_:
-                    continue
-                i = len(words)
-                id_ = int(id_)-1
-                head = int(head)-1 if head != '0' else id_
-                head_dist = head - id_
-                words.append(word)
-                tags.append(tag)
-                heads.append(i+head_dist)
-                deps.append('ROOT' if dep == 'root' else dep)
-        heads, deps = projectivize(heads, deps)
-        entities = ['-'] * len(words)
-        gold = GoldParse(sd, words=words, tags=tags, heads=heads, deps=deps,
-                         entities=entities)
-        golds.append(gold)
-    return golds
-
-
-def parse_dev_data(nlp, text_loc, conllu_loc):
-    with open(text_loc) as file_:
-        docs = get_docs(nlp, file_.read())
-    with open(conllu_loc) as file_:
-        conllu_dev = read_conllu(file_)
-    golds = list(get_golds(docs, conllu_dev))
+def parse_dev_data(nlp, text_loc, conllu_loc, oracle_segments=False,
+                   joint_sbd=True):
+    with open(text_loc) as text_file:
+        with open(conllu_loc) as conllu_file:
+            docs, golds = read_data(nlp, conllu_file, text_file,
+                                    oracle_segments=oracle_segments)
+    if not joint_sbd:
+        sbd = nlp.create_pipe('sentencizer')
+        for doc in docs:
+            doc = sbd(doc)
+            for sent in doc.sents:
+                sent[0].is_sent_start = True
+                for word in sent[1:]:
+                    word.is_sent_start = False
     scorer = nlp.evaluate(zip(docs, golds))
     return docs, scorer
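
Note: the new joint_sbd flag controls whether the parser does its own sentence segmentation. When it is False, the rule-based sentencizer pre-segments the Doc and the boundaries are frozen before evaluation. A minimal sketch of that pattern, assuming a blank English pipeline rather than the model this script loads:

import spacy

nlp = spacy.blank('en')
sbd = nlp.create_pipe('sentencizer')
doc = sbd(nlp.make_doc("This is one sentence. Here is another."))
# Freeze the sentencizer's decisions so the parser cannot move the boundaries.
for sent in doc.sents:
    sent[0].is_sent_start = True
    for word in sent[1:]:
        word.is_sent_start = False
print([sent.text for sent in doc.sents])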
@@ -186,20 +233,19 @@ def print_conllu(docs, file_):
            head = 0
        else:
            head = k + (t.head.i - t.i) + 1
-        fields = [str(k+1), t.text, t.lemma_, t.pos_, t.tag_, '_', str(head), t.dep_, '_', '_']
+        fields = [str(k+1), t.text, t.lemma_, t.pos_, t.tag_, '_',
+                  str(head), t.dep_.lower(), '_', '_']
        file_.write('\t'.join(fields) + '\n')
    file_.write('\n')

 def main(spacy_model, conllu_train_loc, text_train_loc, conllu_dev_loc, text_dev_loc,
          output_loc):
-    with open(conllu_train_loc) as file_:
-        conllu_train = read_conllu(file_)
     nlp = load_model(spacy_model)
-    print("Get docs")
-    with open(text_train_loc) as file_:
-        docs = get_docs(nlp, file_.read())
-    golds = list(get_golds(docs, conllu_train))
+    with open(conllu_train_loc) as conllu_file:
+        with open(text_train_loc) as text_file:
+            docs, golds = read_data(nlp, conllu_file, text_file,
+                                    oracle_segments=False, raw_text=True)
     print("Create parser")
     nlp.add_pipe(nlp.create_pipe('parser'))
     nlp.add_pipe(nlp.create_pipe('tagger'))
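
Note: in print_conllu() the HEAD column has to be 1-based and relative to the sentence, while spaCy token indices are 0-based and document-wide. A tiny worked example of that arithmetic, with plain integers standing in for the spaCy attributes (the helper name is made up for illustration):

def conllu_head(k, token_i, head_i):
    """k: position within the sentence; token_i/head_i: document-level indices."""
    if head_i == token_i:   # spaCy marks the root as its own head
        return 0            # CoNLL-U uses 0 for the root instead
    return k + (head_i - token_i) + 1

# Second token of a sentence (k=1), attached to the sentence's first token.
print(conllu_head(k=1, token_i=6, head_i=5))   # -> 1
# Sentence-initial token (k=0) that is the root of its sentence.
print(conllu_head(k=0, token_i=5, head_i=5))   # -> 0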
@@ -221,15 +267,14 @@ def main(spacy_model, conllu_train_loc, text_train_loc, conllu_dev_loc, text_dev
     # Batch size starts at 1 and grows, so that we make updates quickly
     # at the beginning of training.
     batch_sizes = spacy.util.compounding(spacy.util.env_opt('batch_from', 1),
-                                         spacy.util.env_opt('batch_to', 8),
+                                         spacy.util.env_opt('batch_to', 2),
                                          spacy.util.env_opt('batch_compound', 1.001))
-    for i in range(10):
-        with open(text_train_loc) as file_:
-            docs = get_docs(nlp, file_.read())
-        docs = docs[:len(golds)]
+    for i in range(30):
+        docs = refresh_docs(docs)
+        batches = minibatch(list(zip(docs, golds)), size=batch_sizes)
         with tqdm.tqdm(total=n_train_words, leave=False) as pbar:
             losses = {}
-            for batch in minibatch(list(zip(docs, golds)), size=batch_sizes):
+            for batch in batches:
                 if not batch:
                     continue
                 batch_docs, batch_gold = zip(*batch)
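
Note: batch_sizes here is an infinite generator. spacy.util.compounding starts at batch_from, multiplies by batch_compound on every draw, and is capped at batch_to; minibatch (imported from spacy.gold at the top of this script) draws one value per batch. A small sketch of the schedule with toy numbers, not the env_opt defaults above:

from itertools import islice

import spacy.util
from spacy.gold import minibatch

print(list(islice(spacy.util.compounding(1., 3., 1.5), 5)))
# -> [1.0, 1.5, 2.25, 3.0, 3.0]

# minibatch() draws one size per batch, so batches grow over the epoch.
sizes = spacy.util.compounding(1., 3., 1.5)
print([len(batch) for batch in minibatch(list(range(10)), size=sizes)])
# -> [1, 1, 2, 3, 3]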
@@ -239,10 +284,18 @@ def main(spacy_model, conllu_train_loc, text_train_loc, conllu_dev_loc, text_dev
                 pbar.update(sum(len(doc) for doc in batch_docs))

         with nlp.use_params(optimizer.averages):
-            dev_docs, scorer = parse_dev_data(nlp, text_dev_loc, conllu_dev_loc)
+            dev_docs, scorer = parse_dev_data(nlp, text_dev_loc, conllu_dev_loc,
+                                              oracle_segments=False, joint_sbd=True)
+            print_progress(i, losses, scorer)
+            with open(output_loc, 'w') as file_:
+                print_conllu(dev_docs, file_)
+            dev_docs, scorer = parse_dev_data(nlp, text_dev_loc, conllu_dev_loc,
+                                              oracle_segments=False, joint_sbd=False)
             print_progress(i, losses, scorer)
             with open(output_loc, 'w') as file_:
                 print_conllu(dev_docs, file_)


 if __name__ == '__main__':
     plac.call(main)
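
Note: plac.call(main) maps positional command-line arguments onto main()'s parameters in order, so running the script is equivalent to a direct call like the one below. The model name and all paths are placeholders, not values taken from the commit:

main('en_core_web_sm',             # spacy_model, passed to load_model()
     'train.conllu', 'train.txt',  # gold CoNLL-U and raw text for training
     'dev.conllu', 'dev.txt',      # dev-set counterparts evaluated each epoch
     'dev_parses.conllu')          # output_loc, where print_conllu() writes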