Minibatch by number of tokens, support other vectors, refactor CoNLL printing

This commit is contained in:
Matthew Honnibal 2018-02-25 10:38:06 +01:00
parent dd78ef066a
commit c388833ca6
1 changed files with 49 additions and 12 deletions

View File

@ -8,21 +8,38 @@ import re
import sys
import spacy
import spacy.util
from spacy.tokens import Doc
from spacy.tokens import Token, Doc
from spacy.gold import GoldParse, minibatch
from spacy.syntax.nonproj import projectivize
from collections import defaultdict, Counter
from timeit import default_timer as timer
from spacy.matcher import Matcher
import itertools
import random
import numpy.random
import cytoolz
from spacy._align import align
random.seed(0)
numpy.random.seed(0)
def minibatch_by_words(items, size=5000):
if isinstance(size, int):
size_ = itertools.repeat(size)
else:
size_ = size
items = iter(items)
while True:
batch_size = next(size_)
batch = []
while batch_size >= 0:
doc, gold = next(items)
batch_size -= len(doc)
batch.append((doc, gold))
yield batch
def get_token_acc(docs, golds):
'''Quick function to evaluate tokenization accuracy.'''
@ -214,31 +231,51 @@ def print_conllu(docs, file_):
offsets = [(span.start_char, span.end_char) for span in spans]
for start_char, end_char in offsets:
doc.merge(start_char, end_char)
#print([t.text for t in doc])
file_.write("# newdoc id = {i}\n".format(i=i))
for j, sent in enumerate(doc.sents):
file_.write("# sent_id = {i}.{j}\n".format(i=i, j=j))
file_.write("# text = {text}\n".format(text=sent.text))
for k, t in enumerate(sent):
if t.head.i == t.i:
for k, token in enumerate(sent):
file_.write(token._.get_conllu_lines(k) + '\n')
file_.write('\n')
#def get_sent_conllu(sent, sent_id):
# lines = ["# sent_id = {sent_id}".format(sent_id=sent_id)]
def get_token_conllu(token, i):
if token._.begins_fused:
n = 1
while token.nbor(n)._.inside_fused:
n += 1
id_ = '%d-%d' % (k, k+n)
lines = [id_, token.text, '_', '_', '_', '_', '_', '_', '_', '_']
else:
lines = []
if token.head.i == token.i:
head = 0
else:
head = k + (t.head.i - t.i) + 1
fields = [str(k+1), t.text, t.lemma_, t.pos_, t.tag_, '_',
str(head), t.dep_.lower(), '_', '_']
file_.write('\t'.join(fields) + '\n')
file_.write('\n')
head = i + (token.head.i - token.i) + 1
fields = [str(i+1), token.text, token.lemma_, token.pos_, token.tag_, '_',
str(head), token.dep_.lower(), '_', '_']
lines.append('\t'.join(fields))
return '\n'.join(lines)
Token.set_extension('get_conllu_lines', method=get_token_conllu)
Token.set_extension('begins_fused', default=False)
Token.set_extension('inside_fused', default=False)
def main(lang, conllu_train_loc, text_train_loc, conllu_dev_loc, text_dev_loc,
output_loc):
nlp = spacy.blank(lang)
if lang == 'en':
nlp = spacy.blank(lang)
vec_nlp = spacy.util.load_model('spacy/data/en_core_web_lg/en_core_web_lg-2.0.0')
nlp.vocab.vectors = vec_nlp.vocab.vectors
for lex in vec_nlp.vocab:
_ = nlp.vocab[lex.orth_]
vec_nlp = None
else:
nlp = spacy.load(lang)
with open(conllu_train_loc) as conllu_file:
with open(text_train_loc) as text_file:
docs, golds = read_data(nlp, conllu_file, text_file,
@ -272,7 +309,7 @@ def main(lang, conllu_train_loc, text_train_loc, conllu_dev_loc, text_dev_loc,
spacy.util.env_opt('batch_compound', 1.001))
for i in range(30):
docs = refresh_docs(docs)
batches = minibatch(list(zip(docs, golds)), size=batch_sizes)
batches = minibatch_by_words(list(zip(docs, golds)), size=1000)
with tqdm.tqdm(total=n_train_words, leave=False) as pbar:
losses = {}
for batch in batches: