From d1eea2d865b0b42d02195143788343ea3eb620b3 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Sun, 6 Sep 2015 17:51:48 +0200 Subject: [PATCH] * Update train.py for language-generic spaCy --- bin/parser/train.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/bin/parser/train.py b/bin/parser/train.py index 68217fcb3..abd5eb16e 100755 --- a/bin/parser/train.py +++ b/bin/parser/train.py @@ -14,7 +14,6 @@ import re import spacy.util from spacy.en import English -from spacy.en.pos import POS_TEMPLATES, POS_TAGS, setup_model_dir from spacy.syntax.util import Config from spacy.gold import read_json_file @@ -22,6 +21,11 @@ from spacy.gold import GoldParse from spacy.scorer import Scorer +from spacy.syntax.arc_eager import ArcEager +from spacy.syntax.ner import BiluoPushDown +from spacy.tagger import Tagger +from spacy.syntax.parser import Parser + def _corrupt(c, noise_level): if random.random() >= noise_level: @@ -80,32 +84,28 @@ def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic', beam_width=1, verbose=False, use_orig_arc_eager=False): dep_model_dir = path.join(model_dir, 'deps') - pos_model_dir = path.join(model_dir, 'pos') ner_model_dir = path.join(model_dir, 'ner') if path.exists(dep_model_dir): shutil.rmtree(dep_model_dir) - if path.exists(pos_model_dir): - shutil.rmtree(pos_model_dir) if path.exists(ner_model_dir): shutil.rmtree(ner_model_dir) os.mkdir(dep_model_dir) - os.mkdir(pos_model_dir) os.mkdir(ner_model_dir) - setup_model_dir(sorted(POS_TAGS.keys()), POS_TAGS, POS_TEMPLATES, pos_model_dir) - Config.write(dep_model_dir, 'config', features=feat_set, seed=seed, - labels=Language.ParserTransitionSystem.get_labels(gold_tuples), + labels=ArcEager.get_labels(gold_tuples), beam_width=beam_width) Config.write(ner_model_dir, 'config', features='ner', seed=seed, - labels=Language.EntityTransitionSystem.get_labels(gold_tuples), + labels=BiluoPushDown.get_labels(gold_tuples), beam_width=0) if n_sents > 0: gold_tuples = gold_tuples[:n_sents] - nlp = Language(data_dir=model_dir) - + nlp = Language(data_dir=model_dir, tagger=False, parser=False, entity=False) + nlp.tagger = Tagger.blank(nlp.vocab, Tagger.default_templates()) + nlp.parser = Parser.from_dir(dep_model_dir, nlp.vocab.strings, ArcEager) + nlp.entity = Parser.from_dir(ner_model_dir, nlp.vocab.strings, BiluoPushDown) print("Itn.\tP.Loss\tUAS\tNER F.\tTag %\tToken %") for itn in range(n_iter): scorer = Scorer() @@ -140,7 +140,7 @@ def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic', print('%d:\t%d\t%.3f\t%.3f\t%.3f\t%.3f' % (itn, loss, scorer.uas, scorer.ents_f, scorer.tags_acc, scorer.token_acc)) - nlp.end_training() + nlp.end_training(model_dir) def evaluate(Language, gold_tuples, model_dir, gold_preproc=False, verbose=False, beam_width=None):