From a676d668070cc3a23de30db99e55ee5f7593f515 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Tue, 2 Feb 2016 22:29:34 +0100 Subject: [PATCH] * Update the CoNLL train script, to get working on other languages --- bin/parser/conll_train.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/bin/parser/conll_train.py b/bin/parser/conll_train.py index 40c7b71f8..da9bb807a 100755 --- a/bin/parser/conll_train.py +++ b/bin/parser/conll_train.py @@ -5,7 +5,7 @@ from __future__ import unicode_literals import os from os import path import shutil -import codecs +import io import random import time import gzip @@ -56,12 +56,20 @@ def _parse_line(line): if len(pieces) == 4: word, pos, head_idx, label = pieces head_idx = int(head_idx) + elif len(pieces) == 15: + id_ = int(pieces[0].split('_')[-1]) + word = pieces[1] + pos = pieces[4] + head_idx = int(pieces[8])-1 + label = pieces[10] else: - id_ = int(pieces[0]) + id_ = int(pieces[0].split('_')[-1]) word = pieces[1] pos = pieces[4] head_idx = int(pieces[6])-1 label = pieces[7] + if head_idx == 0: + label = 'ROOT' return word, pos, head_idx, label @@ -69,8 +77,8 @@ def score_model(scorer, nlp, raw_text, annot_tuples, verbose=False): tokens = nlp.tokenizer.tokens_from_list(annot_tuples[1]) nlp.tagger(tokens) nlp.parser(tokens) - gold = GoldParse(tokens, annot_tuples) - scorer.score(tokens, gold, verbose=verbose) + gold = GoldParse(tokens, annot_tuples, make_projective=False) + scorer.score(tokens, gold, verbose=verbose, punct_labels=('--', 'p', 'punct')) def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic', seed=0, @@ -122,11 +130,11 @@ def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic', seed=0 def main(train_loc, dev_loc, model_dir): - with codecs.open(train_loc, 'r', 'utf8') as file_: + with io.open(train_loc, 'r', encoding='utf8') as file_: train_sents = read_conll(file_) - train(English, train_sents, model_dir) + #train(English, train_sents, model_dir) nlp = English(data_dir=model_dir) - dev_sents = read_conll(open(dev_loc)) + dev_sents = read_conll(io.open(dev_loc, 'r', encoding='utf8')) scorer = Scorer() for _, sents in dev_sents: for annot_tuples, _ in sents: