From e2ed6251d761c115b5849d404f2e33a94044eda5 Mon Sep 17 00:00:00 2001
From: Matthew Honnibal
Date: Tue, 2 Feb 2016 22:58:06 +0100
Subject: [PATCH] * Fancy up the CLI for the conll train script

---
Review note: main() now accepts eval_only (default False). The decorator
annotates it as the -e flag and the body reads it, so without the parameter
the name is never bound and the script fails with a NameError. The parameter
is appended last so existing callers of main() are unaffected.

 bin/parser/conll_train.py | 12 ++++++++++--
 1 file changed, 10 insertions(+), 2 deletions(-)

diff --git a/bin/parser/conll_train.py b/bin/parser/conll_train.py
index da9bb807a..8075dcd8a 100755
--- a/bin/parser/conll_train.py
+++ b/bin/parser/conll_train.py
@@ -129,10 +129,18 @@ def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic', seed=0
     print('done')
 
 
-def main(train_loc, dev_loc, model_dir):
+@plac.annotations(
+    train_loc=("Location of CoNLL 09 formatted training file"),
+    dev_loc=("Location of CoNLL 09 formatted development file"),
+    model_dir=("Location of output model directory"),
+    eval_only=("Skip training, and only evaluate", "flag", "e", bool),
+    n_iter=("Number of training iterations", "option", "i", int),
+)
+def main(train_loc, dev_loc, model_dir, n_iter=15, eval_only=False):
     with io.open(train_loc, 'r', encoding='utf8') as file_:
         train_sents = read_conll(file_)
-    #train(English, train_sents, model_dir)
+    if not eval_only:
+        train(English, train_sents, model_dir, n_iter=n_iter)
     nlp = English(data_dir=model_dir)
     dev_sents = read_conll(io.open(dev_loc, 'r', encoding='utf8'))
     scorer = Scorer()