From b4348ce1c38eb9ccee3988372f39b8408f209a9b Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Thu, 29 Jan 2015 04:21:13 +1100 Subject: [PATCH] * Messily use unsegmented sentences to train the parser --- bin/parser/train.py | 63 +++++++++++++++++++++++++++++++++++---------- 1 file changed, 50 insertions(+), 13 deletions(-) diff --git a/bin/parser/train.py b/bin/parser/train.py index 0b214a20c..67f01ee95 100755 --- a/bin/parser/train.py +++ b/bin/parser/train.py @@ -26,6 +26,7 @@ def read_tokenized_gold(file_): """Read a standard CoNLL/MALT-style format""" sents = [] for sent_str in file_.read().strip().split('\n\n'): + ids = [] words = [] heads = [] labels = [] @@ -35,10 +36,11 @@ def read_tokenized_gold(file_): words.append(word) if head_idx == -1: head_idx = i + ids.append(id_) heads.append(head_idx) labels.append(label) tags.append(pos_string) - sents.append((words, heads, labels, tags)) + sents.append((ids_, words, heads, labels, tags)) return sents @@ -49,31 +51,62 @@ def read_docparse_gold(file_): heads = [] labels = [] tags = [] + ids = [] lines = sent_str.strip().split('\n') raw_text = lines[0] tok_text = lines[1] for i, line in enumerate(lines[2:]): - word, pos_string, head_idx, label = _parse_line(line) + id_, word, pos_string, head_idx, label = _parse_line(line) + if label == 'root': + label = 'ROOT' + if pos_string == "``": + word = "``" + elif pos_string == "''": + word = "''" words.append(word) - if head_idx == -1: - head_idx = i + if head_idx < 0: + head_idx = id_ + ids.append(id_) heads.append(head_idx) labels.append(label) tags.append(pos_string) - words = tok_text.replace('', ' ').replace('', ' ').split(' ') + heads = _map_indices_to_tokens(ids, heads) + words = tok_text.replace('', ' ').replace('', ' ').split() + #print words + #print heads sents.append((words, heads, labels, tags)) + #sent_strings = tok_text.split('') + #for sent in sent_strings: + # sent_words = sent.replace('', ' ').split(' ') + # sent_heads = [] + # sent_labels = [] + # sent_tags = [] + # sent_ids = [] + # while len(sent_heads) < len(sent_words): + # sent_heads.append(heads.pop(0)) + # sent_labels.append(labels.pop(0)) + # sent_tags.append(tags.pop(0)) + # sent_ids.append(ids.pop(0)) + # sent_heads = _map_indices_to_tokens(sent_ids, sent_heads) + # sents.append((sent_words, sent_heads, sent_labels, sent_tags)) return sents +def _map_indices_to_tokens(ids, heads): + return [ids.index(head) for head in heads] + + + def _parse_line(line): pieces = line.split() if len(pieces) == 4: - return pieces[0], pieces[1], int(pieces[2]) - 1, pieces[3] + return 0, pieces[0], pieces[1], int(pieces[2]) - 1, pieces[3] else: + id_ = int(pieces[0]) word = pieces[1] pos = pieces[3] - head_idx = int(pieces[6]) - 1 + head_idx = int(pieces[6]) label = pieces[7] - return word, pos, head_idx, label + return id_, word, pos, head_idx, label def get_labels(sents): left_labels = set() @@ -113,7 +146,11 @@ def train(Language, sents, model_dir, n_iter=15, feat_set=u'basic', seed=0): tags = [nlp.tagger.tag_names.index(tag) for tag in tags] tokens = nlp.tokenizer.tokens_from_list(words) nlp.tagger(tokens) - heads_corr += nlp.parser.train_sent(tokens, heads, labels) + try: + heads_corr += nlp.parser.train_sent(tokens, heads, labels, force_gold=False) + except: + print heads + raise pos_corr += nlp.tagger.train(tokens, tags) n_tokens += len(tokens) acc = float(heads_corr) / n_tokens @@ -122,7 +159,6 @@ def train(Language, sents, model_dir, n_iter=15, feat_set=u'basic', seed=0): random.shuffle(sents) nlp.parser.model.end_training() nlp.tagger.model.end_training() - #nlp.parser.model.dump(path.join(dep_model_dir, 'model'), freq_thresh=0) return acc @@ -131,13 +167,13 @@ def evaluate(Language, dev_loc, model_dir): n_corr = 0 total = 0 with codecs.open(dev_loc, 'r', 'utf8') as file_: - sents = read_tokenized_gold(file_) + sents = read_docparse_gold(file_) for words, heads, labels, tags in sents: tokens = nlp.tokenizer.tokens_from_list(words) nlp.tagger(tokens) nlp.parser(tokens) for i, token in enumerate(tokens): - #print i, token.string, i + token.head, heads[i], labels[i] + #print i, token.orth_, token.head.orth_, tokens[heads[i]].orth_, labels[i], token.head.i == heads[i] if labels[i] == 'P' or labels[i] == 'punct': continue n_corr += token.head.i == heads[i] @@ -150,7 +186,8 @@ PROFILE = False def main(train_loc, dev_loc, model_dir): with codecs.open(train_loc, 'r', 'utf8') as file_: - train_sents = read_tokenized_gold(file_) + train_sents = read_docparse_gold(file_) + train_sents = train_sents if PROFILE: import cProfile import pstats