diff --git a/bin/parser/train.py b/bin/parser/train.py index 747bf978d..89d9b1f4c 100755 --- a/bin/parser/train.py +++ b/bin/parser/train.py @@ -61,13 +61,8 @@ def read_docparse_gold(file_): tags = [] ids = [] lines = sent_str.strip().split('\n') -<<<<<<< HEAD raw_text = lines.pop(0).strip() tok_text = lines.pop(0).strip() -======= - raw_text = lines.pop(0) - tok_text = lines.pop(0) ->>>>>>> master for i, line in enumerate(lines): id_, word, pos_string, head_idx, label = _parse_line(line) if label == 'root': @@ -200,9 +195,9 @@ def train(Language, paragraphs, model_dir, n_iter=15, feat_set=u'basic', seed=0, setup_model_dir(sorted(POS_TAGS.keys()), POS_TAGS, POS_TEMPLATES, pos_model_dir) - left_labels, right_labels = get_labels(paragraphs) + labels = Language.ParserTransitionSystem.get_labels(gold_sents) Config.write(dep_model_dir, 'config', features=feat_set, seed=seed, - left_labels=left_labels, right_labels=right_labels) + labels=labels) nlp = Language() @@ -210,14 +205,12 @@ def train(Language, paragraphs, model_dir, n_iter=15, feat_set=u'basic', seed=0, heads_corr = 0 pos_corr = 0 n_tokens = 0 - for tokens, tag_strs, heads, labels in iter_data(paragraphs, nlp.tokenizer, - gold_preproc=gold_preproc): + for gold_sent in gold_sents: + tokens = nlp.tokenizer(gold_sent.raw) + gold_sent.align_to_tokens(tokens) nlp.tagger(tokens) - try: - heads_corr += nlp.parser.train_sent(tokens, heads, labels, force_gold=force_gold) - except OracleError: - continue - pos_corr += nlp.tagger.train(tokens, tag_strs) + heads_corr += nlp.parser.train(tokens, gold_sent, force_gold=force_gold) + pos_corr += nlp.tagger.train(tokens, gold_parse.tags) n_tokens += len(tokens) acc = float(heads_corr) / n_tokens pos_acc = float(pos_corr) / n_tokens @@ -265,10 +258,9 @@ def evaluate(Language, dev_loc, model_dir, gold_preproc=False): def main(train_loc, dev_loc, model_dir): - with codecs.open(train_loc, 'r', 'utf8') as file_: - train_sents = read_docparse_gold(file_) - train(English, train_sents, model_dir, gold_preproc=False, force_gold=False) - print evaluate(English, dev_loc, model_dir, gold_preproc=False) + train(English, read_docparse_gold(train_loc), model_dir, + gold_preproc=False, force_gold=False) + print evaluate(English, read_docparse_gold(dev_loc), model_dir, gold_preproc=False) if __name__ == '__main__':