diff --git a/bin/parser/train.py b/bin/parser/train.py index 63b0d47bb..5d588a317 100755 --- a/bin/parser/train.py +++ b/bin/parser/train.py @@ -17,6 +17,7 @@ import spacy.util from spacy.syntax.util import Config from spacy.gold import read_json_file from spacy.gold import GoldParse +from spacy.gold import merge_sents from spacy.scorer import Scorer @@ -63,22 +64,6 @@ def score_model(scorer, nlp, raw_text, annot_tuples, verbose=False): scorer.score(tokens, gold, verbose=verbose) -def _merge_sents(sents): - m_deps = [[], [], [], [], [], []] - m_brackets = [] - i = 0 - for (ids, words, tags, heads, labels, ner), brackets in sents: - m_deps[0].extend(id_ + i for id_ in ids) - m_deps[1].extend(words) - m_deps[2].extend(tags) - m_deps[3].extend(head + i for head in heads) - m_deps[4].extend(labels) - m_deps[5].extend(ner) - m_brackets.extend((b['first'] + i, b['last'] + i, b['label']) for b in brackets) - i += len(ids) - return [(m_deps, m_brackets)] - - def train(Language, train_data, dev_data, model_dir, tagger_cfg, parser_cfg, entity_cfg, n_iter=15, seed=0, gold_preproc=False, n_sents=0, corruption_level=0): print("Itn.\tP.Loss\tUAS\tNER F.\tTag %\tToken %") @@ -86,10 +71,11 @@ def train(Language, train_data, dev_data, model_dir, tagger_cfg, parser_cfg, ent with Language.train(model_dir, train_data, tagger_cfg, parser_cfg, entity_cfg) as trainer: loss = 0 - for itn, epoch in enumerate(trainer.epochs(n_iter, augment_data=None)): + for itn, epoch in enumerate(trainer.epochs(n_iter, gold_preproc=gold_preproc, + augment_data=None)): for doc, gold in epoch: trainer.update(doc, gold) - dev_scores = trainer.evaluate(dev_data) + dev_scores = trainer.evaluate(dev_data, gold_preproc=gold_preproc) print(format_str.format(itn, loss, **dev_scores.scores)) @@ -105,7 +91,7 @@ def evaluate(Language, gold_tuples, model_dir, gold_preproc=False, verbose=False if gold_preproc: raw_text = None else: - sents = _merge_sents(sents) + sents = merge_sents(sents) for annot_tuples, brackets in sents: if raw_text is None: tokens = nlp.tokenizer.tokens_from_list(annot_tuples[1])