From 001e2ec6d6feff1367d173117592c894205ae0b4 Mon Sep 17 00:00:00 2001
From: Matthew Honnibal
Date: Thu, 22 Feb 2018 16:00:34 +0100
Subject: [PATCH] Refactor CoNLL training script

---
 examples/training/conllu.py | 175 +++++++++++++++++++++++-------------
 1 file changed, 114 insertions(+), 61 deletions(-)

diff --git a/examples/training/conllu.py b/examples/training/conllu.py
index fa4fefcea..50716a0e1 100644
--- a/examples/training/conllu.py
+++ b/examples/training/conllu.py
@@ -5,8 +5,10 @@ from __future__ import unicode_literals
 import plac
 import tqdm
 import re
+import sys
 import spacy
 import spacy.util
+from spacy.tokens import Doc
 from spacy.gold import GoldParse, minibatch
 from spacy.syntax.nonproj import projectivize
 from collections import Counter
@@ -78,16 +80,81 @@ def golds_to_gold_tuples(docs, golds):
         tuples.append((text, sents))
     return tuples
 
-
 def split_text(text):
-    paragraphs = text.split('\n\n')
-    paragraphs = [par.strip().replace('\n', ' ') for par in paragraphs]
-    return paragraphs
+    return [par.strip().replace('\n', ' ')
+            for par in text.split('\n\n')]
+
+
+def read_data(nlp, conllu_file, text_file, raw_text=True, oracle_segments=False,
+              limit=None):
+    '''Read the CONLLU format into (Doc, GoldParse) tuples. If raw_text=True,
+    include Doc objects created using nlp.make_doc and then aligned against
+    the gold-standard sequences. If oracle_segments=True, include Doc objects
+    created from the gold-standard segments. At least one must be True.'''
+    if not raw_text and not oracle_segments:
+        raise ValueError("At least one of raw_text or oracle_segments must be True")
+    paragraphs = split_text(text_file.read())
+    conllu = read_conllu(conllu_file)
+    # sd is spacy doc; cd is conllu doc
+    # cs is conllu sent, ct is conllu token
+    docs = []
+    golds = []
+    for text, cd in zip(paragraphs, conllu):
+        doc_words = []
+        doc_tags = []
+        doc_heads = []
+        doc_deps = []
+        doc_ents = []
+        for cs in cd:
+            sent_words = []
+            sent_tags = []
+            sent_heads = []
+            sent_deps = []
+            for id_, word, lemma, pos, tag, morph, head, dep, _1, _2 in cs:
+                if '.' in id_:
+                    continue
+                if '-' in id_:
+                    continue
+                id_ = int(id_)-1
+                head = int(head)-1 if head != '0' else id_
+                sent_words.append(word)
+                sent_tags.append(tag)
+                sent_heads.append(head)
+                sent_deps.append('ROOT' if dep == 'root' else dep)
+            if oracle_segments:
+                sent_heads, sent_deps = projectivize(sent_heads, sent_deps)
+                docs.append(Doc(nlp.vocab, words=sent_words))
+                golds.append(GoldParse(docs[-1], words=sent_words, heads=sent_heads,
+                                       tags=sent_tags, deps=sent_deps,
+                                       entities=['-']*len(sent_words)))
+            for head in sent_heads:
+                doc_heads.append(len(doc_words)+head)
+            doc_words.extend(sent_words)
+            doc_tags.extend(sent_tags)
+            doc_deps.extend(sent_deps)
+            doc_ents.extend(['-']*len(sent_words))
+        # Create a GoldParse object for the sentence
+        doc_heads, doc_deps = projectivize(doc_heads, doc_deps)
+        if raw_text:
+            docs.append(nlp.make_doc(text))
+            golds.append(GoldParse(docs[-1], words=doc_words, tags=doc_tags,
+                                   heads=doc_heads, deps=doc_deps,
+                                   entities=doc_ents))
+        if limit and len(docs) >= limit:
+            break
+    return docs, golds
+
+
+def refresh_docs(docs):
+    vocab = docs[0].vocab
+    return [Doc(vocab, words=[t.text for t in doc],
+                spaces=[t.whitespace_ for t in doc])
+            for doc in docs]
 
 
 def read_conllu(file_):
     docs = []
-    doc = []
+    doc = None
     sent = []
     for line in file_:
         if line.startswith('# newdoc'):
@@ -98,57 +165,37 @@
             continue
         elif not line.strip():
             if sent:
-                doc.append(sent)
+                if doc is None:
+                    docs.append([sent])
+                else:
+                    doc.append(sent)
             sent = []
         else:
             sent.append(line.strip().split())
     if sent:
-        doc.append(sent)
+        if doc is None:
+            docs.append([sent])
+        else:
+            doc.append(sent)
     if doc:
         docs.append(doc)
     return docs
 
 
-def get_docs(nlp, text):
-    paragraphs = split_text(text)
-    docs = [nlp.make_doc(par) for par in paragraphs]
-    return docs
-
-
-def get_golds(docs, conllu):
-    # sd is spacy doc; cd is conllu doc
-    # cs is conllu sent, ct is conllu token
-    golds = []
-    for sd, cd in zip(docs, conllu):
-        words = []
-        tags = []
-        heads = []
-        deps = []
-        for cs in cd:
-            for id_, word, lemma, pos, tag, morph, head, dep, _1, _2 in cs:
-                if '.' in id_:
-                    continue
-                i = len(words)
-                id_ = int(id_)-1
-                head = int(head)-1 if head != '0' else id_
-                head_dist = head - id_
-                words.append(word)
-                tags.append(tag)
-                heads.append(i+head_dist)
-                deps.append('ROOT' if dep == 'root' else dep)
-        heads, deps = projectivize(heads, deps)
-        entities = ['-'] * len(words)
-        gold = GoldParse(sd, words=words, tags=tags, heads=heads, deps=deps,
-                         entities=entities)
-        golds.append(gold)
-    return golds
-
-def parse_dev_data(nlp, text_loc, conllu_loc):
-    with open(text_loc) as file_:
-        docs = get_docs(nlp, file_.read())
-    with open(conllu_loc) as file_:
-        conllu_dev = read_conllu(file_)
-    golds = list(get_golds(docs, conllu_dev))
+def parse_dev_data(nlp, text_loc, conllu_loc, oracle_segments=False,
+                   joint_sbd=True):
+    with open(text_loc) as text_file:
+        with open(conllu_loc) as conllu_file:
+            docs, golds = read_data(nlp, conllu_file, text_file,
+                                    oracle_segments=oracle_segments)
+    if not joint_sbd:
+        sbd = nlp.create_pipe('sentencizer')
+        for doc in docs:
+            doc = sbd(doc)
+            for sent in doc.sents:
+                sent[0].is_sent_start = True
+                for word in sent[1:]:
+                    word.is_sent_start = False
     scorer = nlp.evaluate(zip(docs, golds))
     return docs, scorer
 
@@ -186,20 +233,19 @@ def print_conllu(docs, file_):
                 head = 0
             else:
                 head = k + (t.head.i - t.i) + 1
-            fields = [str(k+1), t.text, t.lemma_, t.pos_, t.tag_, '_', str(head), t.dep_, '_', '_']
+            fields = [str(k+1), t.text, t.lemma_, t.pos_, t.tag_, '_',
+                      str(head), t.dep_.lower(), '_', '_']
             file_.write('\t'.join(fields) + '\n')
         file_.write('\n')
 
 
 def main(spacy_model, conllu_train_loc, text_train_loc, conllu_dev_loc,
          text_dev_loc, output_loc):
-    with open(conllu_train_loc) as file_:
-        conllu_train = read_conllu(file_)
     nlp = load_model(spacy_model)
-    print("Get docs")
-    with open(text_train_loc) as file_:
-        docs = get_docs(nlp, file_.read())
-    golds = list(get_golds(docs, conllu_train))
+    with open(conllu_train_loc) as conllu_file:
+        with open(text_train_loc) as text_file:
+            docs, golds = read_data(nlp, conllu_file, text_file,
+                                    oracle_segments=False, raw_text=True)
     print("Create parser")
     nlp.add_pipe(nlp.create_pipe('parser'))
     nlp.add_pipe(nlp.create_pipe('tagger'))
@@ -221,15 +267,14 @@ def main(spacy_model, conllu_train_loc, text_train_loc, conllu_dev_loc, text_dev
     # Batch size starts at 1 and grows, so that we make updates quickly
     # at the beginning of training.
     batch_sizes = spacy.util.compounding(spacy.util.env_opt('batch_from', 1),
-                                         spacy.util.env_opt('batch_to', 8),
+                                         spacy.util.env_opt('batch_to', 2),
                                          spacy.util.env_opt('batch_compound', 1.001))
-    for i in range(10):
-        with open(text_train_loc) as file_:
-            docs = get_docs(nlp, file_.read())
-        docs = docs[:len(golds)]
+    for i in range(30):
+        docs = refresh_docs(docs)
+        batches = minibatch(list(zip(docs, golds)), size=batch_sizes)
         with tqdm.tqdm(total=n_train_words, leave=False) as pbar:
             losses = {}
-            for batch in minibatch(list(zip(docs, golds)), size=batch_sizes):
+            for batch in batches:
                 if not batch:
                     continue
                 batch_docs, batch_gold = zip(*batch)
@@ -239,10 +284,18 @@ def main(spacy_model, conllu_train_loc, text_train_loc, conllu_dev_loc, text_dev
                 pbar.update(sum(len(doc) for doc in batch_docs))
 
         with nlp.use_params(optimizer.averages):
-            dev_docs, scorer = parse_dev_data(nlp, text_dev_loc, conllu_dev_loc)
+            dev_docs, scorer = parse_dev_data(nlp, text_dev_loc, conllu_dev_loc,
+                                              oracle_segments=False, joint_sbd=True)
+            print_progress(i, losses, scorer)
+            with open(output_loc, 'w') as file_:
+                print_conllu(dev_docs, file_)
+            dev_docs, scorer = parse_dev_data(nlp, text_dev_loc, conllu_dev_loc,
+                                              oracle_segments=False, joint_sbd=False)
             print_progress(i, losses, scorer)
             with open(output_loc, 'w') as file_:
                 print_conllu(dev_docs, file_)
+
+
 
 if __name__ == '__main__':
     plac.call(main)
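-- 
A quick usage sketch, not part of the commit itself: main() is exposed
through plac, so the six positional command-line arguments map directly
onto its parameters. The model name and data paths below are placeholders
for a UD-style corpus:

    python examples/training/conllu.py en_core_web_sm \
        train.conllu train.txt dev.conllu dev.txt dev.predicted.conllu

Note that each epoch now parses the dev data twice: once with
joint_sbd=True, where the parser predicts sentence boundaries jointly with
the parse, and once with joint_sbd=False, where the rule-based sentencizer
pre-segments the text first. Both runs write to output_loc, so the file on
disk ends up holding the second (sentencizer-segmented) parses.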