diff --git a/bin/parser/train.py b/bin/parser/train.py index 267b26275..489d9259c 100755 --- a/bin/parser/train.py +++ b/bin/parser/train.py @@ -165,17 +165,29 @@ def evaluate(Language, gold_tuples, model_dir, gold_preproc=False, verbose=False return scorer -def write_parses(Language, dev_loc, model_dir, out_loc): - nlp = Language() - gold_tuples = read_docparse_file(dev_loc) +def write_parses(Language, dev_loc, model_dir, out_loc, beam_width=None): + nlp = Language(data_dir=model_dir) + if beam_width is not None: + nlp.parser.cfg.beam_width = beam_width + gold_tuples = read_json_file(dev_loc) scorer = Scorer() out_file = codecs.open(out_loc, 'w', 'utf8') - for raw_text, segmented_text, annot_tuples in gold_tuples: - tokens = nlp(raw_text) - for t in tokens: - out_file.write( - '%s\t%s\t%s\t%s\n' % (t.orth_, t.tag_, t.head.orth_, t.dep_) - ) + for raw_text, sents in gold_tuples: + sents = _merge_sents(sents) + for annot_tuples, brackets in sents: + if raw_text is None: + tokens = nlp.tokenizer.tokens_from_list(annot_tuples[1]) + nlp.tagger(tokens) + nlp.entity(tokens) + nlp.parser(tokens) + else: + tokens = nlp(raw_text, merge_mwes=False) + gold = GoldParse(tokens, annot_tuples) + scorer.score(tokens, gold, verbose=False) + for t in tokens: + out_file.write( + '%s\t%s\t%s\t%s\n' % (t.orth_, t.tag_, t.head.orth_, t.dep_) + ) return scorer @@ -204,7 +216,7 @@ def main(train_loc, dev_loc, model_dir, n_sents=0, n_iter=15, out_loc="", verbos corruption_level=corruption_level, n_iter=n_iter, beam_width=beam_width) if out_loc: - write_parses(English, dev_loc, model_dir, out_loc) + write_parses(English, dev_loc, model_dir, out_loc, beam_width=beam_width) scorer = evaluate(English, list(read_json_file(dev_loc)), model_dir, gold_preproc=gold_preproc, verbose=verbose, beam_width=beam_width)