mirror of https://github.com/explosion/spaCy.git
* Update the CoNLL train script, to get working on other languages
This commit is contained in:
parent
6c633f2edc
commit
a676d66807
|
@ -5,7 +5,7 @@ from __future__ import unicode_literals
|
|||
import os
|
||||
from os import path
|
||||
import shutil
|
||||
import codecs
|
||||
import io
|
||||
import random
|
||||
import time
|
||||
import gzip
|
||||
|
@ -56,12 +56,20 @@ def _parse_line(line):
|
|||
if len(pieces) == 4:
|
||||
word, pos, head_idx, label = pieces
|
||||
head_idx = int(head_idx)
|
||||
elif len(pieces) == 15:
|
||||
id_ = int(pieces[0].split('_')[-1])
|
||||
word = pieces[1]
|
||||
pos = pieces[4]
|
||||
head_idx = int(pieces[8])-1
|
||||
label = pieces[10]
|
||||
else:
|
||||
id_ = int(pieces[0])
|
||||
id_ = int(pieces[0].split('_')[-1])
|
||||
word = pieces[1]
|
||||
pos = pieces[4]
|
||||
head_idx = int(pieces[6])-1
|
||||
label = pieces[7]
|
||||
if head_idx == 0:
|
||||
label = 'ROOT'
|
||||
return word, pos, head_idx, label
|
||||
|
||||
|
||||
|
@ -69,8 +77,8 @@ def score_model(scorer, nlp, raw_text, annot_tuples, verbose=False):
|
|||
tokens = nlp.tokenizer.tokens_from_list(annot_tuples[1])
|
||||
nlp.tagger(tokens)
|
||||
nlp.parser(tokens)
|
||||
gold = GoldParse(tokens, annot_tuples)
|
||||
scorer.score(tokens, gold, verbose=verbose)
|
||||
gold = GoldParse(tokens, annot_tuples, make_projective=False)
|
||||
scorer.score(tokens, gold, verbose=verbose, punct_labels=('--', 'p', 'punct'))
|
||||
|
||||
|
||||
def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic', seed=0,
|
||||
|
@ -122,11 +130,11 @@ def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic', seed=0
|
|||
|
||||
|
||||
def main(train_loc, dev_loc, model_dir):
|
||||
with codecs.open(train_loc, 'r', 'utf8') as file_:
|
||||
with io.open(train_loc, 'r', encoding='utf8') as file_:
|
||||
train_sents = read_conll(file_)
|
||||
train(English, train_sents, model_dir)
|
||||
#train(English, train_sents, model_dir)
|
||||
nlp = English(data_dir=model_dir)
|
||||
dev_sents = read_conll(open(dev_loc))
|
||||
dev_sents = read_conll(io.open(dev_loc, 'r', encoding='utf8'))
|
||||
scorer = Scorer()
|
||||
for _, sents in dev_sents:
|
||||
for annot_tuples, _ in sents:
|
||||
|
|
Loading…
Reference in New Issue