Update conllu script

This commit is contained in:
Matthew Honnibal 2018-02-22 19:43:54 +01:00
parent 9c8a0f6eba
commit a26e399f84
1 changed files with 16 additions and 6 deletions

View File

@ -28,6 +28,8 @@ def prevent_bad_sentences(doc):
token.is_sent_start = False token.is_sent_start = False
elif not token.nbor(-1).is_punct: elif not token.nbor(-1).is_punct:
token.is_sent_start = False token.is_sent_start = False
elif token.nbor(-1).is_left_punct:
token.is_sent_start = False
return doc return doc
@ -99,7 +101,7 @@ def read_data(nlp, conllu_file, text_file, raw_text=True, oracle_segments=False,
# cs is conllu sent, ct is conllu token # cs is conllu sent, ct is conllu token
docs = [] docs = []
golds = [] golds = []
for text, cd in zip(paragraphs, conllu): for doc_id, (text, cd) in enumerate(zip(paragraphs, conllu)):
doc_words = [] doc_words = []
doc_tags = [] doc_tags = []
doc_heads = [] doc_heads = []
@ -140,7 +142,7 @@ def read_data(nlp, conllu_file, text_file, raw_text=True, oracle_segments=False,
golds.append(GoldParse(docs[-1], words=doc_words, tags=doc_tags, golds.append(GoldParse(docs[-1], words=doc_words, tags=doc_tags,
heads=doc_heads, deps=doc_deps, heads=doc_heads, deps=doc_deps,
entities=doc_ents)) entities=doc_ents))
if limit and len(docs) >= limit: if limit and doc_id >= limit:
break break
return docs, golds return docs, golds
@ -188,7 +190,14 @@ def parse_dev_data(nlp, text_loc, conllu_loc, oracle_segments=False,
with open(conllu_loc) as conllu_file: with open(conllu_loc) as conllu_file:
docs, golds = read_data(nlp, conllu_file, text_file, docs, golds = read_data(nlp, conllu_file, text_file,
oracle_segments=oracle_segments) oracle_segments=oracle_segments)
if not joint_sbd: if joint_sbd:
sbd = nlp.create_pipe('sentencizer')
for doc in docs:
doc = sbd(doc)
for sent in doc.sents:
sent[0].is_sent_start = True
#docs = (prevent_bad_sentences(doc) for doc in docs)
else:
sbd = nlp.create_pipe('sentencizer') sbd = nlp.create_pipe('sentencizer')
for doc in docs: for doc in docs:
doc = sbd(doc) doc = sbd(doc)
@ -245,7 +254,8 @@ def main(spacy_model, conllu_train_loc, text_train_loc, conllu_dev_loc, text_dev
with open(conllu_train_loc) as conllu_file: with open(conllu_train_loc) as conllu_file:
with open(text_train_loc) as text_file: with open(text_train_loc) as text_file:
docs, golds = read_data(nlp, conllu_file, text_file, docs, golds = read_data(nlp, conllu_file, text_file,
oracle_segments=False, raw_text=True) oracle_segments=True, raw_text=True,
limit=None)
print("Create parser") print("Create parser")
nlp.add_pipe(nlp.create_pipe('parser')) nlp.add_pipe(nlp.create_pipe('parser'))
nlp.add_pipe(nlp.create_pipe('tagger')) nlp.add_pipe(nlp.create_pipe('tagger'))
@ -266,7 +276,7 @@ def main(spacy_model, conllu_train_loc, text_train_loc, conllu_dev_loc, text_dev
print("Begin training") print("Begin training")
# Batch size starts at 1 and grows, so that we make updates quickly # Batch size starts at 1 and grows, so that we make updates quickly
# at the beginning of training. # at the beginning of training.
batch_sizes = spacy.util.compounding(spacy.util.env_opt('batch_from', 1), batch_sizes = spacy.util.compounding(spacy.util.env_opt('batch_from', 2),
spacy.util.env_opt('batch_to', 2), spacy.util.env_opt('batch_to', 2),
spacy.util.env_opt('batch_compound', 1.001)) spacy.util.env_opt('batch_compound', 1.001))
for i in range(30): for i in range(30):
@ -278,6 +288,7 @@ def main(spacy_model, conllu_train_loc, text_train_loc, conllu_dev_loc, text_dev
if not batch: if not batch:
continue continue
batch_docs, batch_gold = zip(*batch) batch_docs, batch_gold = zip(*batch)
batch_docs = [prevent_bad_sentences(doc) for doc in batch_docs]
nlp.update(batch_docs, batch_gold, sgd=optimizer, nlp.update(batch_docs, batch_gold, sgd=optimizer,
drop=0.2, losses=losses) drop=0.2, losses=losses)
@ -296,6 +307,5 @@ def main(spacy_model, conllu_train_loc, text_train_loc, conllu_dev_loc, text_dev
print_conllu(dev_docs, file_) print_conllu(dev_docs, file_)
if __name__ == '__main__': if __name__ == '__main__':
plac.call(main) plac.call(main)