From a26e399f84f1688b1a2a8e6503aa44eb5c0136ea Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Thu, 22 Feb 2018 19:43:54 +0100 Subject: [PATCH] Update conllu script --- examples/training/conllu.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/examples/training/conllu.py b/examples/training/conllu.py index 50716a0e1..271a049c0 100644 --- a/examples/training/conllu.py +++ b/examples/training/conllu.py @@ -28,6 +28,8 @@ def prevent_bad_sentences(doc): token.is_sent_start = False elif not token.nbor(-1).is_punct: token.is_sent_start = False + elif token.nbor(-1).is_left_punct: + token.is_sent_start = False return doc @@ -99,7 +101,7 @@ def read_data(nlp, conllu_file, text_file, raw_text=True, oracle_segments=False, # cs is conllu sent, ct is conllu token docs = [] golds = [] - for text, cd in zip(paragraphs, conllu): + for doc_id, (text, cd) in enumerate(zip(paragraphs, conllu)): doc_words = [] doc_tags = [] doc_heads = [] @@ -140,7 +142,7 @@ def read_data(nlp, conllu_file, text_file, raw_text=True, oracle_segments=False, golds.append(GoldParse(docs[-1], words=doc_words, tags=doc_tags, heads=doc_heads, deps=doc_deps, entities=doc_ents)) - if limit and len(docs) >= limit: + if limit and doc_id >= limit: break return docs, golds @@ -188,7 +190,14 @@ def parse_dev_data(nlp, text_loc, conllu_loc, oracle_segments=False, with open(conllu_loc) as conllu_file: docs, golds = read_data(nlp, conllu_file, text_file, oracle_segments=oracle_segments) - if not joint_sbd: + if joint_sbd: + sbd = nlp.create_pipe('sentencizer') + for doc in docs: + doc = sbd(doc) + for sent in doc.sents: + sent[0].is_sent_start = True + #docs = (prevent_bad_sentences(doc) for doc in docs) + else: sbd = nlp.create_pipe('sentencizer') for doc in docs: doc = sbd(doc) @@ -245,7 +254,8 @@ def main(spacy_model, conllu_train_loc, text_train_loc, conllu_dev_loc, text_dev with open(conllu_train_loc) as conllu_file: with open(text_train_loc) as text_file: docs, golds = read_data(nlp, conllu_file, text_file, - oracle_segments=False, raw_text=True) + oracle_segments=True, raw_text=True, + limit=None) print("Create parser") nlp.add_pipe(nlp.create_pipe('parser')) nlp.add_pipe(nlp.create_pipe('tagger')) @@ -266,7 +276,7 @@ def main(spacy_model, conllu_train_loc, text_train_loc, conllu_dev_loc, text_dev print("Begin training") # Batch size starts at 1 and grows, so that we make updates quickly # at the beginning of training. - batch_sizes = spacy.util.compounding(spacy.util.env_opt('batch_from', 1), + batch_sizes = spacy.util.compounding(spacy.util.env_opt('batch_from', 2), spacy.util.env_opt('batch_to', 2), spacy.util.env_opt('batch_compound', 1.001)) for i in range(30): @@ -278,6 +288,7 @@ def main(spacy_model, conllu_train_loc, text_train_loc, conllu_dev_loc, text_dev if not batch: continue batch_docs, batch_gold = zip(*batch) + batch_docs = [prevent_bad_sentences(doc) for doc in batch_docs] nlp.update(batch_docs, batch_gold, sgd=optimizer, drop=0.2, losses=losses) @@ -296,6 +307,5 @@ def main(spacy_model, conllu_train_loc, text_train_loc, conllu_dev_loc, text_dev print_conllu(dev_docs, file_) - if __name__ == '__main__': plac.call(main)