mirror of https://github.com/explosion/spaCy.git
* Wire hyperparameters to script interface
This commit is contained in:
parent
ebe630cc8d
commit
da793073d0
|
@ -84,7 +84,8 @@ def _merge_sents(sents):
|
||||||
def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic',
|
def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic',
|
||||||
seed=0, gold_preproc=False, n_sents=0, corruption_level=0,
|
seed=0, gold_preproc=False, n_sents=0, corruption_level=0,
|
||||||
verbose=False,
|
verbose=False,
|
||||||
eta=0.01, mu=0.9, n_hidden=100, word_vec_len=10, pos_vec_len=10):
|
eta=0.01, mu=0.9, n_hidden=100,
|
||||||
|
nv_word=10, nv_tag=10, nv_label=10):
|
||||||
dep_model_dir = path.join(model_dir, 'deps')
|
dep_model_dir = path.join(model_dir, 'deps')
|
||||||
pos_model_dir = path.join(model_dir, 'pos')
|
pos_model_dir = path.join(model_dir, 'pos')
|
||||||
ner_model_dir = path.join(model_dir, 'ner')
|
ner_model_dir = path.join(model_dir, 'ner')
|
||||||
|
@ -99,8 +100,15 @@ def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic',
|
||||||
os.mkdir(ner_model_dir)
|
os.mkdir(ner_model_dir)
|
||||||
setup_model_dir(sorted(POS_TAGS.keys()), POS_TAGS, POS_TEMPLATES, pos_model_dir)
|
setup_model_dir(sorted(POS_TAGS.keys()), POS_TAGS, POS_TEMPLATES, pos_model_dir)
|
||||||
|
|
||||||
Config.write(dep_model_dir, 'config', features=feat_set, seed=seed,
|
Config.write(dep_model_dir, 'config',
|
||||||
labels=Language.ParserTransitionSystem.get_labels(gold_tuples))
|
seed=seed,
|
||||||
|
features=feat_set,
|
||||||
|
labels=Language.ParserTransitionSystem.get_labels(gold_tuples),
|
||||||
|
vector_lengths=(nv_word, nv_tag, nv_label),
|
||||||
|
hidden_nodes=n_hidden,
|
||||||
|
eta=eta,
|
||||||
|
mu=mu
|
||||||
|
)
|
||||||
Config.write(ner_model_dir, 'config', features='ner', seed=seed,
|
Config.write(ner_model_dir, 'config', features='ner', seed=seed,
|
||||||
labels=Language.EntityTransitionSystem.get_labels(gold_tuples),
|
labels=Language.EntityTransitionSystem.get_labels(gold_tuples),
|
||||||
beam_width=0)
|
beam_width=0)
|
||||||
|
@ -110,16 +118,17 @@ def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic',
|
||||||
|
|
||||||
nlp = Language(data_dir=model_dir)
|
nlp = Language(data_dir=model_dir)
|
||||||
|
|
||||||
def make_model(n_classes, input_spec, model_dir):
|
def make_model(n_classes, (words, tags, labels), model_dir):
|
||||||
print input_spec
|
n_in = (nv_word * len(words)) + \
|
||||||
n_in = sum(n_cols * len(fields) for (n_cols, fields) in input_spec)
|
(nv_tag * len(tags)) + \
|
||||||
|
(nv_label * len(labels))
|
||||||
print 'Compiling'
|
print 'Compiling'
|
||||||
debug, train_func, predict_func = compile_theano_model(n_classes, n_hidden,
|
debug, train_func, predict_func = compile_theano_model(n_classes, n_hidden,
|
||||||
n_in, 0.0, 0.0)
|
n_in, 0.0, 0.0)
|
||||||
print 'Done'
|
print 'Done'
|
||||||
return TheanoModel(
|
return TheanoModel(
|
||||||
n_classes,
|
n_classes,
|
||||||
input_spec,
|
((nv_word, words), (nv_tag, tags), (nv_label, labels)),
|
||||||
train_func,
|
train_func,
|
||||||
predict_func,
|
predict_func,
|
||||||
model_loc=model_dir,
|
model_loc=model_dir,
|
||||||
|
@ -226,14 +235,23 @@ def write_parses(Language, dev_loc, model_dir, out_loc, beam_width=None):
|
||||||
n_sents=("Number of training sentences", "option", "n", int),
|
n_sents=("Number of training sentences", "option", "n", int),
|
||||||
n_iter=("Number of training iterations", "option", "i", int),
|
n_iter=("Number of training iterations", "option", "i", int),
|
||||||
verbose=("Verbose error reporting", "flag", "v", bool),
|
verbose=("Verbose error reporting", "flag", "v", bool),
|
||||||
debug=("Debug mode", "flag", "d", bool),
|
|
||||||
|
nv_word=("Word vector length", "option", "W", int),
|
||||||
|
nv_tag=("Tag vector length", "option", "T", int),
|
||||||
|
nv_label=("Label vector length", "option", "L", int),
|
||||||
|
nv_hidden=("Hidden nodes length", "option", "H", int),
|
||||||
|
eta=("Learning rate", "option", "E", float),
|
||||||
|
mu=("Momentum", "option", "M", float),
|
||||||
)
|
)
|
||||||
def main(train_loc, dev_loc, model_dir, n_sents=0, n_iter=15, out_loc="", verbose=False,
|
def main(train_loc, dev_loc, model_dir, n_sents=0, n_iter=15, out_loc="", verbose=False,
|
||||||
debug=False, corruption_level=0.0, gold_preproc=False, beam_width=1,
|
corruption_level=0.0, gold_preproc=False,
|
||||||
|
nv_word=10, nv_tag=10, nv_label=10, nv_hidden=10,
|
||||||
|
eta=0.1, mu=0.9,
|
||||||
eval_only=False):
|
eval_only=False):
|
||||||
gold_train = list(read_json_file(train_loc))
|
gold_train = list(read_json_file(train_loc))
|
||||||
nlp = train(English, gold_train, model_dir,
|
nlp = train(English, gold_train, model_dir,
|
||||||
feat_set='embed',
|
feat_set='embed',
|
||||||
|
nv_word=nv_word, nv_tag=nv_tag, nv_label=nv_label,
|
||||||
gold_preproc=gold_preproc, n_sents=n_sents,
|
gold_preproc=gold_preproc, n_sents=n_sents,
|
||||||
corruption_level=corruption_level, n_iter=n_iter,
|
corruption_level=corruption_level, n_iter=n_iter,
|
||||||
verbose=verbose)
|
verbose=verbose)
|
||||||
|
|
Loading…
Reference in New Issue