mirror of https://github.com/explosion/spaCy.git
* Wire up beam-width command line argument
This commit is contained in:
parent
7c29362d60
commit
a3de20118e
|
@ -159,7 +159,8 @@ def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic',
|
||||||
setup_model_dir(sorted(POS_TAGS.keys()), POS_TAGS, POS_TEMPLATES, pos_model_dir)
|
setup_model_dir(sorted(POS_TAGS.keys()), POS_TAGS, POS_TEMPLATES, pos_model_dir)
|
||||||
|
|
||||||
Config.write(dep_model_dir, 'config', features=feat_set, seed=seed,
|
Config.write(dep_model_dir, 'config', features=feat_set, seed=seed,
|
||||||
labels=Language.ParserTransitionSystem.get_labels(gold_tuples))
|
labels=Language.ParserTransitionSystem.get_labels(gold_tuples),
|
||||||
|
beam_width=16)
|
||||||
Config.write(ner_model_dir, 'config', features='ner', seed=seed,
|
Config.write(ner_model_dir, 'config', features='ner', seed=seed,
|
||||||
labels=Language.EntityTransitionSystem.get_labels(gold_tuples))
|
labels=Language.EntityTransitionSystem.get_labels(gold_tuples))
|
||||||
|
|
||||||
|
@ -248,11 +249,12 @@ def write_parses(Language, dev_loc, model_dir, out_loc):
|
||||||
out_loc=("Out location", "option", "o", str),
|
out_loc=("Out location", "option", "o", str),
|
||||||
n_sents=("Number of training sentences", "option", "n", int),
|
n_sents=("Number of training sentences", "option", "n", int),
|
||||||
n_iter=("Number of training iterations", "option", "i", int),
|
n_iter=("Number of training iterations", "option", "i", int),
|
||||||
|
beam_width=("Number of candidates to maintain in the beam", "option", "k", int),
|
||||||
verbose=("Verbose error reporting", "flag", "v", bool),
|
verbose=("Verbose error reporting", "flag", "v", bool),
|
||||||
debug=("Debug mode", "flag", "d", bool)
|
debug=("Debug mode", "flag", "d", bool)
|
||||||
)
|
)
|
||||||
def main(train_loc, dev_loc, model_dir, n_sents=0, n_iter=15, out_loc="", verbose=False,
|
def main(train_loc, dev_loc, model_dir, n_sents=0, n_iter=15, out_loc="", verbose=False,
|
||||||
debug=False, corruption_level=0.0, gold_preproc=False):
|
debug=False, corruption_level=0.0, gold_preproc=False, beam_width=1):
|
||||||
gold_train = list(read_json_file(train_loc))
|
gold_train = list(read_json_file(train_loc))
|
||||||
#taggings = get_train_tags(English, model_dir, gold_train, gold_preproc)
|
#taggings = get_train_tags(English, model_dir, gold_train, gold_preproc)
|
||||||
taggings = None
|
taggings = None
|
||||||
|
@ -260,7 +262,7 @@ def main(train_loc, dev_loc, model_dir, n_sents=0, n_iter=15, out_loc="", verbos
|
||||||
feat_set='basic' if not debug else 'debug',
|
feat_set='basic' if not debug else 'debug',
|
||||||
gold_preproc=gold_preproc, n_sents=n_sents,
|
gold_preproc=gold_preproc, n_sents=n_sents,
|
||||||
corruption_level=corruption_level, n_iter=n_iter,
|
corruption_level=corruption_level, n_iter=n_iter,
|
||||||
train_tags=taggings)
|
train_tags=taggings, beam_width=beam_width)
|
||||||
if out_loc:
|
if out_loc:
|
||||||
write_parses(English, dev_loc, model_dir, out_loc)
|
write_parses(English, dev_loc, model_dir, out_loc)
|
||||||
scorer = evaluate(English, list(read_json_file(dev_loc)),
|
scorer = evaluate(English, list(read_json_file(dev_loc)),
|
||||||
|
|
Loading…
Reference in New Issue