mirror of https://github.com/explosion/spaCy.git
* Add script to train models off the UD treebanks. Note that the UD data is restricted to research purposes only, and should only be used to train models for academic experiments.
This commit is contained in:
parent
b228a8f4a6
commit
6ea1601e93
|
@ -0,0 +1,151 @@
|
||||||
|
import plac
|
||||||
|
import json
|
||||||
|
from os import path
|
||||||
|
import shutil
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
|
||||||
|
from spacy.syntax.util import Config
|
||||||
|
from spacy.gold import GoldParse
|
||||||
|
from spacy.tokenizer import Tokenizer
|
||||||
|
from spacy.vocab import Vocab
|
||||||
|
from spacy.tagger import Tagger
|
||||||
|
from spacy.syntax.parser import Parser
|
||||||
|
from spacy.syntax.arc_eager import ArcEager
|
||||||
|
from spacy.syntax.parser import get_templates
|
||||||
|
from spacy.scorer import Scorer
|
||||||
|
|
||||||
|
from spacy.language import Language
|
||||||
|
|
||||||
|
from spacy.tagger import W_orth
|
||||||
|
|
||||||
|
TAGGER_TEMPLATES = (
|
||||||
|
(W_orth,),
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
from codecs import open
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class TreebankParser(object):
|
||||||
|
@staticmethod
|
||||||
|
def setup_model_dir(model_dir, labels, templates, feat_set='basic', seed=0):
|
||||||
|
dep_model_dir = path.join(model_dir, 'deps')
|
||||||
|
pos_model_dir = path.join(model_dir, 'pos')
|
||||||
|
if path.exists(dep_model_dir):
|
||||||
|
shutil.rmtree(dep_model_dir)
|
||||||
|
if path.exists(pos_model_dir):
|
||||||
|
shutil.rmtree(pos_model_dir)
|
||||||
|
os.mkdir(dep_model_dir)
|
||||||
|
os.mkdir(pos_model_dir)
|
||||||
|
|
||||||
|
Config.write(dep_model_dir, 'config', features=feat_set, seed=seed,
|
||||||
|
labels=labels)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dir(cls, tag_map, model_dir):
|
||||||
|
vocab = Vocab(tag_map=tag_map, get_lex_attr=Language.default_lex_attrs())
|
||||||
|
tokenizer = Tokenizer(vocab, {}, None, None, None)
|
||||||
|
tagger = Tagger.blank(vocab, TAGGER_TEMPLATES)
|
||||||
|
|
||||||
|
cfg = Config.read(path.join(model_dir, 'deps'), 'config')
|
||||||
|
parser = Parser.from_dir(path.join(model_dir, 'deps'), vocab.strings, ArcEager)
|
||||||
|
return cls(vocab, tokenizer, tagger, parser)
|
||||||
|
|
||||||
|
def __init__(self, vocab, tokenizer, tagger, parser):
|
||||||
|
self.vocab = vocab
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
self.tagger = tagger
|
||||||
|
self.parser = parser
|
||||||
|
|
||||||
|
def train(self, words, tags, heads, deps):
|
||||||
|
tokens = self.tokenizer.tokens_from_list(list(words))
|
||||||
|
self.tagger.train(tokens, tags)
|
||||||
|
|
||||||
|
tokens = self.tokenizer.tokens_from_list(list(words))
|
||||||
|
ids = range(len(words))
|
||||||
|
ner = ['O'] * len(words)
|
||||||
|
gold = GoldParse(tokens, ((ids, words, tags, heads, deps, ner)),
|
||||||
|
make_projective=False)
|
||||||
|
self.tagger(tokens)
|
||||||
|
if gold.is_projective:
|
||||||
|
try:
|
||||||
|
self.parser.train(tokens, gold)
|
||||||
|
except:
|
||||||
|
for id_, word, head, dep in zip(ids, words, heads, deps):
|
||||||
|
print(id_, word, head, dep)
|
||||||
|
raise
|
||||||
|
|
||||||
|
def __call__(self, words, tags=None):
|
||||||
|
tokens = self.tokenizer.tokens_from_list(list(words))
|
||||||
|
if tags is None:
|
||||||
|
self.tagger(tokens)
|
||||||
|
else:
|
||||||
|
self.tagger.tag_from_strings(tokens, tags)
|
||||||
|
self.parser(tokens)
|
||||||
|
return tokens
|
||||||
|
|
||||||
|
def end_training(self, data_dir):
|
||||||
|
self.parser.model.end_training(path.join(data_dir, 'deps', 'model'))
|
||||||
|
self.tagger.model.end_training(path.join(data_dir, 'pos', 'model'))
|
||||||
|
self.vocab.strings.dump(path.join(data_dir, 'vocab', 'strings.txt'))
|
||||||
|
|
||||||
|
|
||||||
|
def read_conllx(loc):
|
||||||
|
with open(loc, 'r', 'utf8') as file_:
|
||||||
|
text = file_.read()
|
||||||
|
for sent in text.strip().split('\n\n'):
|
||||||
|
lines = sent.strip().split('\n')
|
||||||
|
if lines:
|
||||||
|
if lines[0].startswith('#'):
|
||||||
|
lines.pop(0)
|
||||||
|
tokens = []
|
||||||
|
for line in lines:
|
||||||
|
id_, word, lemma, pos, tag, morph, head, dep, _1, _2 = line.split()
|
||||||
|
if '-' in id_:
|
||||||
|
continue
|
||||||
|
id_ = int(id_) - 1
|
||||||
|
head = (int(head) - 1) if head != '0' else id_
|
||||||
|
dep = 'ROOT' if dep == 'root' else dep
|
||||||
|
tokens.append((id_, word, tag, head, dep, 'O'))
|
||||||
|
tuples = zip(*tokens)
|
||||||
|
yield (None, [(tuples, [])])
|
||||||
|
|
||||||
|
|
||||||
|
def score_model(nlp, gold_docs, verbose=False):
|
||||||
|
scorer = Scorer()
|
||||||
|
for _, gold_doc in gold_docs:
|
||||||
|
for annot_tuples, _ in gold_doc:
|
||||||
|
tokens = nlp(list(annot_tuples[1]), tags=list(annot_tuples[2]))
|
||||||
|
gold = GoldParse(tokens, annot_tuples)
|
||||||
|
scorer.score(tokens, gold, verbose=verbose)
|
||||||
|
return scorer
|
||||||
|
|
||||||
|
|
||||||
|
def main(train_loc, dev_loc, model_dir, tag_map_loc):
|
||||||
|
with open(tag_map_loc) as file_:
|
||||||
|
tag_map = json.loads(file_.read())
|
||||||
|
train_sents = list(read_conllx(train_loc))
|
||||||
|
labels = ArcEager.get_labels(train_sents)
|
||||||
|
templates = get_templates('basic')
|
||||||
|
|
||||||
|
TreebankParser.setup_model_dir(model_dir, labels, templates)
|
||||||
|
|
||||||
|
nlp = TreebankParser.from_dir(tag_map, model_dir)
|
||||||
|
|
||||||
|
for itn in range(15):
|
||||||
|
for _, doc_sents in train_sents:
|
||||||
|
for (ids, words, tags, heads, deps, ner), _ in doc_sents:
|
||||||
|
nlp.train(words, tags, heads, deps)
|
||||||
|
random.shuffle(train_sents)
|
||||||
|
scorer = score_model(nlp, read_conllx(dev_loc))
|
||||||
|
print('%d:\t%.3f\t%.3f' % (itn, scorer.uas, scorer.tags_acc))
|
||||||
|
nlp.end_training(model_dir)
|
||||||
|
scorer = score_model(nlp, read_conllx(dev_loc))
|
||||||
|
print('%d:\t%.3f\t%.3f\t%.3f' % (itn, scorer.uas, scorer.las, scorer.tags_acc))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
plac.call(main)
|
Loading…
Reference in New Issue