From 23236340f4669477376da583f42a5eeccaa58958 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Thu, 22 Feb 2018 21:35:50 +0100 Subject: [PATCH] Update CoNLL script. Don't preset SBD. Set batch size to 8, avoid writing twice --- examples/training/conllu.py | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/examples/training/conllu.py b/examples/training/conllu.py index 271a049c0..a2b4b2fe1 100644 --- a/examples/training/conllu.py +++ b/examples/training/conllu.py @@ -191,12 +191,7 @@ def parse_dev_data(nlp, text_loc, conllu_loc, oracle_segments=False, docs, golds = read_data(nlp, conllu_file, text_file, oracle_segments=oracle_segments) if joint_sbd: - sbd = nlp.create_pipe('sentencizer') - for doc in docs: - doc = sbd(doc) - for sent in doc.sents: - sent[0].is_sent_start = True - #docs = (prevent_bad_sentences(doc) for doc in docs) + pass else: sbd = nlp.create_pipe('sentencizer') for doc in docs: @@ -276,8 +271,8 @@ def main(spacy_model, conllu_train_loc, text_train_loc, conllu_dev_loc, text_dev print("Begin training") # Batch size starts at 1 and grows, so that we make updates quickly # at the beginning of training. - batch_sizes = spacy.util.compounding(spacy.util.env_opt('batch_from', 2), - spacy.util.env_opt('batch_to', 2), + batch_sizes = spacy.util.compounding(spacy.util.env_opt('batch_from', 8), + spacy.util.env_opt('batch_to', 8), spacy.util.env_opt('batch_compound', 1.001)) for i in range(30): docs = refresh_docs(docs) @@ -288,7 +283,6 @@ def main(spacy_model, conllu_train_loc, text_train_loc, conllu_dev_loc, text_dev if not batch: continue batch_docs, batch_gold = zip(*batch) - batch_docs = [prevent_bad_sentences(doc) for doc in batch_docs] nlp.update(batch_docs, batch_gold, sgd=optimizer, drop=0.2, losses=losses) @@ -303,8 +297,6 @@ def main(spacy_model, conllu_train_loc, text_train_loc, conllu_dev_loc, text_dev dev_docs, scorer = parse_dev_data(nlp, text_dev_loc, conllu_dev_loc, oracle_segments=False, joint_sbd=False) print_progress(i, losses, scorer) - with open(output_loc, 'w') as file_: - print_conllu(dev_docs, file_) if __name__ == '__main__':