* Update the CoNLL train script, to get working on other languages

This commit is contained in:
Matthew Honnibal 2016-02-02 22:29:34 +01:00
parent 6c633f2edc
commit a676d66807
1 changed files with 15 additions and 7 deletions

View File

@ -5,7 +5,7 @@ from __future__ import unicode_literals
import os
from os import path
import shutil
import codecs
import io
import random
import time
import gzip
@ -56,12 +56,20 @@ def _parse_line(line):
if len(pieces) == 4:
word, pos, head_idx, label = pieces
head_idx = int(head_idx)
elif len(pieces) == 15:
id_ = int(pieces[0].split('_')[-1])
word = pieces[1]
pos = pieces[4]
head_idx = int(pieces[8])-1
label = pieces[10]
else:
id_ = int(pieces[0])
id_ = int(pieces[0].split('_')[-1])
word = pieces[1]
pos = pieces[4]
head_idx = int(pieces[6])-1
label = pieces[7]
if head_idx == 0:
label = 'ROOT'
return word, pos, head_idx, label
@ -69,8 +77,8 @@ def score_model(scorer, nlp, raw_text, annot_tuples, verbose=False):
tokens = nlp.tokenizer.tokens_from_list(annot_tuples[1])
nlp.tagger(tokens)
nlp.parser(tokens)
gold = GoldParse(tokens, annot_tuples)
scorer.score(tokens, gold, verbose=verbose)
gold = GoldParse(tokens, annot_tuples, make_projective=False)
scorer.score(tokens, gold, verbose=verbose, punct_labels=('--', 'p', 'punct'))
def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic', seed=0,
@ -122,11 +130,11 @@ def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic', seed=0
def main(train_loc, dev_loc, model_dir):
with codecs.open(train_loc, 'r', 'utf8') as file_:
with io.open(train_loc, 'r', encoding='utf8') as file_:
train_sents = read_conll(file_)
train(English, train_sents, model_dir)
#train(English, train_sents, model_dir)
nlp = English(data_dir=model_dir)
dev_sents = read_conll(open(dev_loc))
dev_sents = read_conll(io.open(dev_loc, 'r', encoding='utf8'))
scorer = Scorer()
for _, sents in dev_sents:
for annot_tuples, _ in sents: