From ca7577d8a9a489cd086b37478c5921940c414814 Mon Sep 17 00:00:00 2001
From: Matthew Honnibal
Date: Fri, 30 Jan 2015 16:36:24 +1100
Subject: [PATCH] * Allow parsers and taggers to be trained on text without
 gold pre-processing.

---
 bin/parser/train.py | 43 +++++++++++++++++++++++++++----------------
 1 file changed, 27 insertions(+), 16 deletions(-)

diff --git a/bin/parser/train.py b/bin/parser/train.py
index d901a914f..eb83edb63 100755
--- a/bin/parser/train.py
+++ b/bin/parser/train.py
@@ -9,6 +9,7 @@ import codecs
 import random
 import time
 import gzip
+import nltk
 
 import plac
 import cProfile
@@ -22,6 +23,10 @@ from spacy.syntax.parser import GreedyParser
 from spacy.syntax.util import Config
 
 
+def is_punct_label(label):
+    return label == 'P' or label.lower() == 'punct'
+
+
 def read_tokenized_gold(file_):
     """Read a standard CoNLL/MALT-style format"""
     sents = []
@@ -96,21 +101,21 @@ def _parse_line(line):
 
 
 
+loss = 0
 def _align_annotations_to_non_gold_tokens(tokens, words, annot):
+    global loss
     tags = []
     heads = []
     labels = []
-    loss = 0
-    print [t.orth_ for t in tokens]
-    print words
+    orig_words = list(words)
+    missed = []
     for token in tokens:
-        print token.orth_, words[0]
-        print token.idx, annot[0][0]
         while annot and token.idx > annot[0][0]:
-            print 'pop', token.idx, annot[0][0]
-            annot.pop(0)
-            words.pop(0)
-            loss += 1
+            miss_id, miss_tag, miss_head, miss_label = annot.pop(0)
+            miss_w = words.pop(0)
+            if not is_punct_label(miss_label):
+                missed.append(miss_w)
+                loss += 1
         if not annot:
             tags.append(None)
             heads.append(None)
@@ -129,6 +134,11 @@ def _align_annotations_to_non_gold_tokens(tokens, words, annot):
             labels.append(None)
         else:
             raise StandardError
+    #if missed:
+    #    print orig_words
+    #    print missed
+    #    for t in tokens:
+    #        print t.idx, t.orth_
     return loss, tags, heads, labels
 
 
@@ -137,7 +147,8 @@ def iter_data(paragraphs, tokenizer, gold_preproc=False):
         if not gold_preproc:
             tokens = tokenizer(raw)
             loss, tags, heads, labels = _align_annotations_to_non_gold_tokens(
-                tokens, words, zip(ids, tags, heads, labels))
+                tokens, list(words),
+                zip(ids, tags, heads, labels))
             ids = [t.idx for t in tokens]
             heads = _map_indices_to_tokens(ids, heads)
             yield tokens, tags, heads, labels
@@ -170,7 +181,7 @@ def get_labels(sents):
 
 
 def train(Language, paragraphs, model_dir, n_iter=15, feat_set=u'basic', seed=0,
-          gold_preproc=True):
+          gold_preproc=False):
     dep_model_dir = path.join(model_dir, 'deps')
     pos_model_dir = path.join(model_dir, 'pos')
     if path.exists(dep_model_dir):
@@ -194,10 +205,9 @@ def train(Language, paragraphs, model_dir, n_iter=15, feat_set=u'basic', seed=0,
         n_tokens = 0
         for tokens, tag_strs, heads, labels in iter_data(paragraphs, nlp.tokenizer,
                                                          gold_preproc=gold_preproc):
-            tags = [nlp.tagger.tag_names.index(tag) for tag in tag_strs]
             nlp.tagger(tokens)
             heads_corr += nlp.parser.train_sent(tokens, heads, labels, force_gold=False)
-            pos_corr += nlp.tagger.train(tokens, tags)
+            pos_corr += nlp.tagger.train(tokens, tag_strs)
             n_tokens += len(tokens)
         acc = float(heads_corr) / n_tokens
         pos_acc = float(pos_corr) / n_tokens
@@ -223,12 +233,13 @@ def evaluate(Language, dev_loc, model_dir, gold_preproc=False):
         for i, token in enumerate(tokens):
             if heads[i] is None:
                 skipped += 1
-            if labels[i] == 'P' or labels[i] == 'punct':
+                continue
+            if is_punct_label(labels[i]):
                 continue
             n_corr += token.head.i == heads[i]
             total += 1
-    print skipped
-    return float(n_corr) / total
+    print loss, skipped, (loss+skipped + total)
+    return float(n_corr) / (total + loss)
 
 
 def main(train_loc, dev_loc, model_dir):
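
The core change in _align_annotations_to_non_gold_tokens matches gold annotations
to the parser's own tokenization by character offset, counting gold tokens the
tokenizer merged away as losses. A minimal standalone sketch of that idea, not
part of the patch: the offsets, the align() helper, and the simplified
exact-match branch are invented for illustration, whereas in the patch each
offset comes from a spaCy token's `idx` and the mismatch handling is more
detailed.

    def align(token_offsets, annot):
        # annot: ordered list of (offset, tag, head, label) gold tuples.
        # Note this consumes annot in place, which is why the patch passes
        # a fresh zip(...) rather than the original lists.
        loss = 0
        aligned = []
        for idx in token_offsets:
            # Gold tokens starting before this token have no counterpart
            # in the non-gold tokenization; drop them and record the loss.
            while annot and idx > annot[0][0]:
                annot.pop(0)
                loss += 1
            # Keep the gold annotation only on an exact offset match.
            if annot and idx == annot[0][0]:
                aligned.append(annot.pop(0))
            else:
                aligned.append(None)
        return aligned, loss

    # "can't go": gold tokenizes as ca|n't|go, the parser as can|'t|go.
    gold = [(0, 'MD', 2, 'aux'), (2, 'RB', 2, 'neg'), (6, 'VB', 2, 'ROOT')]
    print align([0, 3, 6], gold)   # ([(0, ...), None, (6, ...)], 1)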
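
On the evaluation side, the same loss count feeds the denominator: gold tokens
that could not be aligned are scored as errors rather than silently dropped, so
accuracy becomes n_corr / (total + loss) instead of n_corr / total. With
illustrative numbers, 90 correct heads out of 95 alignable non-punctuation
tokens plus 5 tokens lost to tokenization mismatches yields 90 / 100 = 0.90,
rather than the inflated 90 / 95 = 0.947.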