mirror of https://github.com/explosion/spaCy.git
142 lines
4.3 KiB
Python
Executable File
142 lines
4.3 KiB
Python
Executable File
#!/usr/bin/env python
|
|
from __future__ import division
|
|
from __future__ import unicode_literals
|
|
|
|
import os
|
|
from os import path
|
|
import shutil
|
|
import codecs
|
|
import random
|
|
import time
|
|
import gzip
|
|
|
|
import plac
|
|
import cProfile
|
|
import pstats
|
|
|
|
import spacy.util
|
|
from spacy.en import English
|
|
from spacy.gold import GoldParse
|
|
|
|
from spacy.syntax.util import Config
|
|
from spacy.syntax.arc_eager import ArcEager
|
|
from spacy.syntax.parser import Parser
|
|
from spacy.scorer import Scorer
|
|
from spacy.tagger import Tagger
|
|
|
|
# Last updated for spaCy v0.97
|
|
|
|
|
|
def read_conll(file_):
    """Read a standard CoNLL/MALT-style dependency file.

    Sentences are separated by one or more blank (or whitespace-only) lines;
    every other line is one token, parsed by ``_parse_line``.  Extra blank
    lines, trailing newlines, CRLF line endings and empty files are handled
    without producing empty sentences.

    Args:
        file_: An open text file (anything whose ``read()`` returns the whole
            contents as a string).

    Returns:
        A list with one ``(None, [(annot, [])])`` entry per sentence, where
        ``annot`` is ``(ids, words, tags, heads, labels, ner)``: ``ids`` are
        0-based token positions, ``heads`` are head positions as returned by
        ``_parse_line`` with any negative head (the root) replaced by the
        token's own position, and ``ner`` is all ``'O'``.

    Raises:
        ValueError: If a token line cannot be parsed.
    """
    sents = []
    for sent_lines in _split_sentences(file_.read()):
        ids = []
        words = []
        tags = []
        heads = []
        labels = []
        for i, line in enumerate(sent_lines):
            word, pos_string, head_idx, label = _parse_line(line)
            if head_idx < 0:
                # The root has no head; represent it as pointing to itself.
                head_idx = i
            ids.append(i)
            words.append(word)
            tags.append(pos_string)
            heads.append(head_idx)
            labels.append(label)
        annot = (ids, words, tags, heads, labels, ['O'] * len(ids))
        sents.append((None, [(annot, [])]))
    return sents


def _split_sentences(text):
    """Group the non-blank lines of *text* into one list of lines per sentence.

    Any run of blank or whitespace-only lines (including a lone ``'\\r'`` left
    over from CRLF endings) ends the current sentence.
    """
    sentences = []
    current = []
    for line in text.split('\n'):
        if line.strip():
            current.append(line)
        elif current:
            sentences.append(current)
            current = []
    if current:
        sentences.append(current)
    return sentences


def _parse_line(line):
    """Parse one token line into ``(word, pos, head_idx, label)``.

    Two layouts are accepted (columns are whitespace-separated):

    * 4 columns ``word pos head label``: ``head`` is converted to int and used
      as-is.
    * 8+ columns (CoNLL-X style): ID, FORM, ..., column 5 as POS, column 7 as
      HEAD and column 8 as DEPREL.  HEAD is shifted by -1, so a head of 0
      becomes -1.

    Raises:
        ValueError: For any other column count, or non-integer ID/HEAD values.
    """
    pieces = line.split()
    if len(pieces) == 4:
        # NOTE(review): unlike the CoNLL branch, no -1 shift is applied here;
        # confirm the 4-column files use 0-based heads with a negative root.
        word, pos, head_idx, label = pieces
        head_idx = int(head_idx)
    elif len(pieces) >= 8:
        int(pieces[0])  # validate the ID column (rejects e.g. "1-2" ranges)
        word = pieces[1]
        pos = pieces[4]
        head_idx = int(pieces[6]) - 1
        label = pieces[7]
    else:
        raise ValueError("Cannot parse CoNLL line: %r" % line)
    return word, pos, head_idx, label
|
|
|
|
|
|
def score_model(scorer, nlp, raw_text, annot_tuples, verbose=False):
    """Tag and parse one gold sentence with *nlp* and accumulate it in *scorer*.

    The document is built directly from the gold word list
    (``annot_tuples[1]``); ``raw_text`` is accepted for interface
    compatibility but is not used.
    """
    doc = nlp.tokenizer.tokens_from_list(annot_tuples[1])
    # Run the pipeline components in order: tagger, then parser.
    for component in (nlp.tagger, nlp.parser):
        component(doc)
    reference = GoldParse(doc, annot_tuples)
    scorer.score(doc, reference, verbose=verbose)
|
|
|
|
|
|
def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic', seed=0,
          gold_preproc=False, force_gold=False):
    """Train a blank tagger and an arc-eager parser, then finish training.

    Any existing ``deps`` and ``pos`` sub-directories of *model_dir* are
    deleted and recreated, and a parser config is written to ``deps``.  On
    each iteration every training sentence with more than one word is first
    scored with the current model and then used to update the parser and
    tagger.  The printed numbers are therefore accumulated over the pass,
    each sentence being scored just before it is trained on.

    Args:
        Language: Language class (e.g. ``English``), instantiated with
            ``data_dir=model_dir`` and tagger/parser/entity disabled.
        gold_tuples: Training data in the format returned by ``read_conll``.
            A copy is shuffled, so the caller's list is left untouched.
        model_dir: Model directory; also passed to ``nlp.end_training``.
        n_iter: Number of passes over the training data.
        feat_set: Feature set name written to the parser config.
        seed: Seed written to the parser config.
        gold_preproc, force_gold: Accepted for compatibility; unused.

    Raises:
        Exception: If a sentence is still non-projective after
            ``GoldParse(..., make_projective=True)``.
    """
    # Work on a private copy: random.shuffle below reorders in place.
    gold_tuples = list(gold_tuples)

    dep_model_dir = path.join(model_dir, 'deps')
    pos_model_dir = path.join(model_dir, 'pos')
    for sub_dir in (dep_model_dir, pos_model_dir):
        if path.exists(sub_dir):
            shutil.rmtree(sub_dir)
        os.mkdir(sub_dir)

    Config.write(dep_model_dir, 'config', features=feat_set, seed=seed,
                 labels=ArcEager.get_labels(gold_tuples))

    nlp = Language(data_dir=model_dir, tagger=False, parser=False, entity=False)
    nlp.tagger = Tagger.blank(nlp.vocab, Tagger.default_templates())
    nlp.parser = Parser.from_dir(dep_model_dir, nlp.vocab.strings, ArcEager)

    # One header column per value in the per-iteration row printed below.
    print("Itn.\tP.Loss\tUAS\tTag %\tToken %")
    for itn in range(n_iter):
        scorer = Scorer()
        loss = 0
        for _, sents in gold_tuples:
            for annot_tuples, _ in sents:
                # Single-word sentences carry no attachment decisions.
                if len(annot_tuples[1]) == 1:
                    continue

                score_model(scorer, nlp, None, annot_tuples, verbose=False)

                tokens = nlp.tokenizer.tokens_from_list(annot_tuples[1])
                nlp.tagger(tokens)
                gold = GoldParse(tokens, annot_tuples, make_projective=True)
                if not gold.is_projective:
                    # Wrap in a 1-tuple: annot_tuples is itself a tuple and
                    # would otherwise be spread over the format arguments.
                    raise Exception(
                        "Non-projective sentence in training, after we should "
                        "have enforced projectivity: %s" % (annot_tuples,)
                    )

                loss += nlp.parser.train(tokens, gold)
                nlp.tagger.train(tokens, gold.tags)
        random.shuffle(gold_tuples)
        print('%d:\t%d\t%.3f\t%.3f\t%.3f' % (itn, loss, scorer.uas,
                                             scorer.tags_acc, scorer.token_acc))
    print('end training')
    nlp.end_training(model_dir)
    print('done')
|
|
|
|
|
|
def main(train_loc, dev_loc, model_dir):
    """Train on *train_loc*, reload from *model_dir* and evaluate on *dev_loc*.

    Both files are read as UTF-8 CoNLL/MALT data via ``read_conll``.  Prints
    ``TOK`` (100 minus ``scorer.token_acc``), ``POS``, ``UAS`` and ``LAS``.
    """
    with codecs.open(train_loc, 'r', 'utf8') as file_:
        train_sents = read_conll(file_)
    train(English, train_sents, model_dir)

    nlp = English(data_dir=model_dir)
    # Read the dev file exactly like the training file (UTF-8) and close it;
    # a bare open() would use the locale encoding and leak the handle.
    with codecs.open(dev_loc, 'r', 'utf8') as file_:
        dev_sents = read_conll(file_)
    scorer = Scorer()
    for _, sents in dev_sents:
        for annot_tuples, _ in sents:
            score_model(scorer, nlp, None, annot_tuples)

    # Format each line as one string: without `print_function`, Python 2
    # would print print('TOK', x) as a tuple. Python 3 output is unchanged.
    print('TOK %s' % (100 - scorer.token_acc))
    print('POS %s' % (scorer.tags_acc,))
    print('UAS %s' % (scorer.uas,))
    print('LAS %s' % (scorer.las,))
|
|
|
|
|
|
def _cli():
    """Command-line entry point: plac builds the argument parser from main()."""
    plac.call(main)


if __name__ == '__main__':
    _cli()
|