mirror of https://github.com/explosion/spaCy.git
Add L1 penalty option to parser
This commit is contained in:
parent
798450136d
commit
35124b144a
|
@ -66,7 +66,7 @@ def score_model(scorer, nlp, raw_text, annot_tuples, verbose=False):
|
||||||
|
|
||||||
def train(Language, train_data, dev_data, model_dir, tagger_cfg, parser_cfg, entity_cfg,
|
def train(Language, train_data, dev_data, model_dir, tagger_cfg, parser_cfg, entity_cfg,
|
||||||
n_iter=15, seed=0, gold_preproc=False, n_sents=0, corruption_level=0):
|
n_iter=15, seed=0, gold_preproc=False, n_sents=0, corruption_level=0):
|
||||||
print("Itn.\tP.Loss\tN feats\tUAS\tNER F.\tTag %\tToken %")
|
print("Itn.\tN weight\tN feats\tUAS\tNER F.\tTag %\tToken %")
|
||||||
format_str = '{:d}\t{:d}\t{:d}\t{uas:.3f}\t{ents_f:.3f}\t{tags_acc:.3f}\t{token_acc:.3f}'
|
format_str = '{:d}\t{:d}\t{:d}\t{uas:.3f}\t{ents_f:.3f}\t{tags_acc:.3f}\t{token_acc:.3f}'
|
||||||
with Language.train(model_dir, train_data,
|
with Language.train(model_dir, train_data,
|
||||||
tagger_cfg, parser_cfg, entity_cfg) as trainer:
|
tagger_cfg, parser_cfg, entity_cfg) as trainer:
|
||||||
|
@ -76,12 +76,13 @@ def train(Language, train_data, dev_data, model_dir, tagger_cfg, parser_cfg, ent
|
||||||
for doc, gold in epoch:
|
for doc, gold in epoch:
|
||||||
trainer.update(doc, gold)
|
trainer.update(doc, gold)
|
||||||
dev_scores = trainer.evaluate(dev_data, gold_preproc=gold_preproc)
|
dev_scores = trainer.evaluate(dev_data, gold_preproc=gold_preproc)
|
||||||
print(format_str.format(itn, loss,
|
print(format_str.format(itn, trainer.nlp.parser.model.nr_weight,
|
||||||
trainer.nlp.parser.model.nr_active_feat, **dev_scores.scores))
|
trainer.nlp.parser.model.nr_active_feat, **dev_scores.scores))
|
||||||
|
|
||||||
|
|
||||||
def evaluate(Language, gold_tuples, model_dir, gold_preproc=False, verbose=False,
|
def evaluate(Language, gold_tuples, model_dir, gold_preproc=False, verbose=False,
|
||||||
beam_width=None, cand_preproc=None):
|
beam_width=None, cand_preproc=None):
|
||||||
|
print("Load parser", model_dir)
|
||||||
nlp = Language(path=model_dir)
|
nlp = Language(path=model_dir)
|
||||||
if nlp.lang == 'de':
|
if nlp.lang == 'de':
|
||||||
nlp.vocab.morphology.lemmatizer = lambda string,pos: set([string])
|
nlp.vocab.morphology.lemmatizer = lambda string,pos: set([string])
|
||||||
|
@ -146,22 +147,25 @@ def write_parses(Language, dev_loc, model_dir, out_loc):
|
||||||
verbose=("Verbose error reporting", "flag", "v", bool),
|
verbose=("Verbose error reporting", "flag", "v", bool),
|
||||||
debug=("Debug mode", "flag", "d", bool),
|
debug=("Debug mode", "flag", "d", bool),
|
||||||
pseudoprojective=("Use pseudo-projective parsing", "flag", "p", bool),
|
pseudoprojective=("Use pseudo-projective parsing", "flag", "p", bool),
|
||||||
|
L1=("L1 regularization penalty", "option", "L", float),
|
||||||
)
|
)
|
||||||
def main(language, train_loc, dev_loc, model_dir, n_sents=0, n_iter=15, out_loc="", verbose=False,
|
def main(language, train_loc, dev_loc, model_dir, n_sents=0, n_iter=15, out_loc="", verbose=False,
|
||||||
debug=False, corruption_level=0.0, gold_preproc=False, eval_only=False, pseudoprojective=False):
|
debug=False, corruption_level=0.0, gold_preproc=False, eval_only=False, pseudoprojective=False,
|
||||||
|
L1=1e-6):
|
||||||
parser_cfg = dict(locals())
|
parser_cfg = dict(locals())
|
||||||
tagger_cfg = dict(locals())
|
tagger_cfg = dict(locals())
|
||||||
entity_cfg = dict(locals())
|
entity_cfg = dict(locals())
|
||||||
|
|
||||||
lang = spacy.util.get_lang_class(language)
|
lang = spacy.util.get_lang_class(language)
|
||||||
|
|
||||||
parser_cfg['features'] = lang.Defaults.parser_features
|
parser_cfg['features'] = lang.Defaults.parser_features
|
||||||
entity_cfg['features'] = lang.Defaults.entity_features
|
entity_cfg['features'] = lang.Defaults.entity_features
|
||||||
|
|
||||||
if not eval_only:
|
if not eval_only:
|
||||||
gold_train = list(read_json_file(train_loc))
|
gold_train = list(read_json_file(train_loc))
|
||||||
gold_dev = list(read_json_file(dev_loc))
|
gold_dev = list(read_json_file(dev_loc))
|
||||||
gold_train = gold_train[:n_sents]
|
if n_sents > 0:
|
||||||
|
gold_train = gold_train[:n_sents]
|
||||||
train(lang, gold_train, gold_dev, model_dir, tagger_cfg, parser_cfg, entity_cfg,
|
train(lang, gold_train, gold_dev, model_dir, tagger_cfg, parser_cfg, entity_cfg,
|
||||||
n_sents=n_sents, gold_preproc=gold_preproc, corruption_level=corruption_level,
|
n_sents=n_sents, gold_preproc=gold_preproc, corruption_level=corruption_level,
|
||||||
n_iter=n_iter)
|
n_iter=n_iter)
|
||||||
|
|
Loading…
Reference in New Issue