mirror of https://github.com/explosion/spaCy.git
Refactor CoNLL training script
This commit is contained in:
parent
6a27a4f77c
commit
001e2ec6d6
|
@ -5,8 +5,10 @@ from __future__ import unicode_literals
|
|||
import plac
|
||||
import tqdm
|
||||
import re
|
||||
import sys
|
||||
import spacy
|
||||
import spacy.util
|
||||
from spacy.tokens import Doc
|
||||
from spacy.gold import GoldParse, minibatch
|
||||
from spacy.syntax.nonproj import projectivize
|
||||
from collections import Counter
|
||||
|
@ -78,16 +80,81 @@ def golds_to_gold_tuples(docs, golds):
|
|||
tuples.append((text, sents))
|
||||
return tuples
|
||||
|
||||
|
||||
def split_text(text):
|
||||
paragraphs = text.split('\n\n')
|
||||
paragraphs = [par.strip().replace('\n', ' ') for par in paragraphs]
|
||||
return paragraphs
|
||||
return [par.strip().replace('\n', ' ')
|
||||
for par in text.split('\n\n')]
|
||||
|
||||
|
||||
def read_data(nlp, conllu_file, text_file, raw_text=True, oracle_segments=False,
|
||||
limit=None):
|
||||
'''Read the CONLLU format into (Doc, GoldParse) tuples. If raw_text=True,
|
||||
include Doc objects created using nlp.make_doc and then aligned against
|
||||
the gold-standard sequences. If oracle_segments=True, include Doc objects
|
||||
created from the gold-standard segments. At least one must be True.'''
|
||||
if not raw_text and not oracle_segments:
|
||||
raise ValueError("At least one of raw_text or oracle_segments must be True")
|
||||
paragraphs = split_text(text_file.read())
|
||||
conllu = read_conllu(conllu_file)
|
||||
# sd is spacy doc; cd is conllu doc
|
||||
# cs is conllu sent, ct is conllu token
|
||||
docs = []
|
||||
golds = []
|
||||
for text, cd in zip(paragraphs, conllu):
|
||||
doc_words = []
|
||||
doc_tags = []
|
||||
doc_heads = []
|
||||
doc_deps = []
|
||||
doc_ents = []
|
||||
for cs in cd:
|
||||
sent_words = []
|
||||
sent_tags = []
|
||||
sent_heads = []
|
||||
sent_deps = []
|
||||
for id_, word, lemma, pos, tag, morph, head, dep, _1, _2 in cs:
|
||||
if '.' in id_:
|
||||
continue
|
||||
if '-' in id_:
|
||||
continue
|
||||
id_ = int(id_)-1
|
||||
head = int(head)-1 if head != '0' else id_
|
||||
sent_words.append(word)
|
||||
sent_tags.append(tag)
|
||||
sent_heads.append(head)
|
||||
sent_deps.append('ROOT' if dep == 'root' else dep)
|
||||
if oracle_segments:
|
||||
sent_heads, sent_deps = projectivize(sent_heads, sent_deps)
|
||||
docs.append(Doc(nlp.vocab, words=sent_words))
|
||||
golds.append(GoldParse(docs[-1], words=sent_words, heads=sent_heads,
|
||||
tags=sent_tags, deps=sent_deps,
|
||||
entities=['-']*len(sent_words)))
|
||||
for head in sent_heads:
|
||||
doc_heads.append(len(doc_words)+head)
|
||||
doc_words.extend(sent_words)
|
||||
doc_tags.extend(sent_tags)
|
||||
doc_deps.extend(sent_deps)
|
||||
doc_ents.extend(['-']*len(sent_words))
|
||||
# Create a GoldParse object for the sentence
|
||||
doc_heads, doc_deps = projectivize(doc_heads, doc_deps)
|
||||
if raw_text:
|
||||
docs.append(nlp.make_doc(text))
|
||||
golds.append(GoldParse(docs[-1], words=doc_words, tags=doc_tags,
|
||||
heads=doc_heads, deps=doc_deps,
|
||||
entities=doc_ents))
|
||||
if limit and len(docs) >= limit:
|
||||
break
|
||||
return docs, golds
|
||||
|
||||
|
||||
def refresh_docs(docs):
|
||||
vocab = docs[0].vocab
|
||||
return [Doc(vocab, words=[t.text for t in doc],
|
||||
spaces=[t.whitespace_ for t in doc])
|
||||
for doc in docs]
|
||||
|
||||
|
||||
def read_conllu(file_):
|
||||
docs = []
|
||||
doc = []
|
||||
doc = None
|
||||
sent = []
|
||||
for line in file_:
|
||||
if line.startswith('# newdoc'):
|
||||
|
@ -98,57 +165,37 @@ def read_conllu(file_):
|
|||
continue
|
||||
elif not line.strip():
|
||||
if sent:
|
||||
if doc is None:
|
||||
docs.append([sent])
|
||||
else:
|
||||
doc.append(sent)
|
||||
sent = []
|
||||
else:
|
||||
sent.append(line.strip().split())
|
||||
if sent:
|
||||
if doc is None:
|
||||
docs.append([sent])
|
||||
else:
|
||||
doc.append(sent)
|
||||
if doc:
|
||||
docs.append(doc)
|
||||
return docs
|
||||
|
||||
|
||||
def get_docs(nlp, text):
|
||||
paragraphs = split_text(text)
|
||||
docs = [nlp.make_doc(par) for par in paragraphs]
|
||||
return docs
|
||||
|
||||
|
||||
def get_golds(docs, conllu):
|
||||
# sd is spacy doc; cd is conllu doc
|
||||
# cs is conllu sent, ct is conllu token
|
||||
golds = []
|
||||
for sd, cd in zip(docs, conllu):
|
||||
words = []
|
||||
tags = []
|
||||
heads = []
|
||||
deps = []
|
||||
for cs in cd:
|
||||
for id_, word, lemma, pos, tag, morph, head, dep, _1, _2 in cs:
|
||||
if '.' in id_:
|
||||
continue
|
||||
i = len(words)
|
||||
id_ = int(id_)-1
|
||||
head = int(head)-1 if head != '0' else id_
|
||||
head_dist = head - id_
|
||||
words.append(word)
|
||||
tags.append(tag)
|
||||
heads.append(i+head_dist)
|
||||
deps.append('ROOT' if dep == 'root' else dep)
|
||||
heads, deps = projectivize(heads, deps)
|
||||
entities = ['-'] * len(words)
|
||||
gold = GoldParse(sd, words=words, tags=tags, heads=heads, deps=deps,
|
||||
entities=entities)
|
||||
golds.append(gold)
|
||||
return golds
|
||||
|
||||
def parse_dev_data(nlp, text_loc, conllu_loc):
|
||||
with open(text_loc) as file_:
|
||||
docs = get_docs(nlp, file_.read())
|
||||
with open(conllu_loc) as file_:
|
||||
conllu_dev = read_conllu(file_)
|
||||
golds = list(get_golds(docs, conllu_dev))
|
||||
def parse_dev_data(nlp, text_loc, conllu_loc, oracle_segments=False,
|
||||
joint_sbd=True):
|
||||
with open(text_loc) as text_file:
|
||||
with open(conllu_loc) as conllu_file:
|
||||
docs, golds = read_data(nlp, conllu_file, text_file,
|
||||
oracle_segments=oracle_segments)
|
||||
if not joint_sbd:
|
||||
sbd = nlp.create_pipe('sentencizer')
|
||||
for doc in docs:
|
||||
doc = sbd(doc)
|
||||
for sent in doc.sents:
|
||||
sent[0].is_sent_start = True
|
||||
for word in sent[1:]:
|
||||
word.is_sent_start = False
|
||||
scorer = nlp.evaluate(zip(docs, golds))
|
||||
return docs, scorer
|
||||
|
||||
|
@ -186,20 +233,19 @@ def print_conllu(docs, file_):
|
|||
head = 0
|
||||
else:
|
||||
head = k + (t.head.i - t.i) + 1
|
||||
fields = [str(k+1), t.text, t.lemma_, t.pos_, t.tag_, '_', str(head), t.dep_, '_', '_']
|
||||
fields = [str(k+1), t.text, t.lemma_, t.pos_, t.tag_, '_',
|
||||
str(head), t.dep_.lower(), '_', '_']
|
||||
file_.write('\t'.join(fields) + '\n')
|
||||
file_.write('\n')
|
||||
|
||||
|
||||
def main(spacy_model, conllu_train_loc, text_train_loc, conllu_dev_loc, text_dev_loc,
|
||||
output_loc):
|
||||
with open(conllu_train_loc) as file_:
|
||||
conllu_train = read_conllu(file_)
|
||||
nlp = load_model(spacy_model)
|
||||
print("Get docs")
|
||||
with open(text_train_loc) as file_:
|
||||
docs = get_docs(nlp, file_.read())
|
||||
golds = list(get_golds(docs, conllu_train))
|
||||
with open(conllu_train_loc) as conllu_file:
|
||||
with open(text_train_loc) as text_file:
|
||||
docs, golds = read_data(nlp, conllu_file, text_file,
|
||||
oracle_segments=False, raw_text=True)
|
||||
print("Create parser")
|
||||
nlp.add_pipe(nlp.create_pipe('parser'))
|
||||
nlp.add_pipe(nlp.create_pipe('tagger'))
|
||||
|
@ -221,15 +267,14 @@ def main(spacy_model, conllu_train_loc, text_train_loc, conllu_dev_loc, text_dev
|
|||
# Batch size starts at 1 and grows, so that we make updates quickly
|
||||
# at the beginning of training.
|
||||
batch_sizes = spacy.util.compounding(spacy.util.env_opt('batch_from', 1),
|
||||
spacy.util.env_opt('batch_to', 8),
|
||||
spacy.util.env_opt('batch_to', 2),
|
||||
spacy.util.env_opt('batch_compound', 1.001))
|
||||
for i in range(10):
|
||||
with open(text_train_loc) as file_:
|
||||
docs = get_docs(nlp, file_.read())
|
||||
docs = docs[:len(golds)]
|
||||
for i in range(30):
|
||||
docs = refresh_docs(docs)
|
||||
batches = minibatch(list(zip(docs, golds)), size=batch_sizes)
|
||||
with tqdm.tqdm(total=n_train_words, leave=False) as pbar:
|
||||
losses = {}
|
||||
for batch in minibatch(list(zip(docs, golds)), size=batch_sizes):
|
||||
for batch in batches:
|
||||
if not batch:
|
||||
continue
|
||||
batch_docs, batch_gold = zip(*batch)
|
||||
|
@ -239,10 +284,18 @@ def main(spacy_model, conllu_train_loc, text_train_loc, conllu_dev_loc, text_dev
|
|||
pbar.update(sum(len(doc) for doc in batch_docs))
|
||||
|
||||
with nlp.use_params(optimizer.averages):
|
||||
dev_docs, scorer = parse_dev_data(nlp, text_dev_loc, conllu_dev_loc)
|
||||
dev_docs, scorer = parse_dev_data(nlp, text_dev_loc, conllu_dev_loc,
|
||||
oracle_segments=False, joint_sbd=True)
|
||||
print_progress(i, losses, scorer)
|
||||
with open(output_loc, 'w') as file_:
|
||||
print_conllu(dev_docs, file_)
|
||||
dev_docs, scorer = parse_dev_data(nlp, text_dev_loc, conllu_dev_loc,
|
||||
oracle_segments=False, joint_sbd=False)
|
||||
print_progress(i, losses, scorer)
|
||||
with open(output_loc, 'w') as file_:
|
||||
print_conllu(dev_docs, file_)
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
plac.call(main)
|
||||
|
|
Loading…
Reference in New Issue