Fix train_ud script, which trains models from corpora in the Universal Dependencies format.

Matthew Honnibal 2016-11-25 11:19:33 -06:00
parent 6dd3b94fa6
commit da5f0cce36
1 changed file with 46 additions and 104 deletions
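The substantive bug fix in read_conllx below is the column order: the old unpacking bound the fourth and fifth columns as `pos, tag`, while the tag map is keyed on the universal POS tag in column 4, so the fixed script binds that column to `tag`. A minimal sketch of the corrected unpacking, against a made-up CoNLL-U style token line (the sample line and its values are purely illustrative):

    # Hypothetical tab-separated CoNLL-U token line; columns are
    # ID FORM LEMMA UPOSTAG XPOSTAG FEATS HEAD DEPREL DEPS MISC.
    line = '2\tdog\tdog\tNOUN\tNN\tNumber=Sing\t3\tnsubj\t_\t_'

    # The fixed unpacking binds column 4 (the universal POS tag) to `tag`,
    # which is what the tag map and the Tagger expect.
    id_, word, lemma, tag, pos, morph, head, dep, _1, _2 = line.split()
    assert tag == 'NOUN' and pos == 'NN'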


@@ -6,103 +6,24 @@ import os
 import random
 import io
 
-from spacy.syntax.util import Config
+from spacy.tokens import Doc
+from spacy.syntax.nonproj import PseudoProjectivity
+from spacy.language import Language
 from spacy.gold import GoldParse
-from spacy.tokenizer import Tokenizer
 from spacy.vocab import Vocab
 from spacy.tagger import Tagger
-from spacy.syntax.parser import Parser
-from spacy.syntax.arc_eager import ArcEager
+from spacy.pipeline import DependencyParser
 from spacy.syntax.parser import get_templates
+from spacy.syntax.arc_eager import ArcEager
 from spacy.scorer import Scorer
 import spacy.attrs
-from spacy.language import Language
-from spacy.tagger import W_orth
-
-TAGGER_TEMPLATES = (
-    (W_orth,),
-)
 
 try:
     from codecs import open
 except ImportError:
     pass
 
-
-class TreebankParser(object):
-    @staticmethod
-    def setup_model_dir(model_dir, labels, templates, feat_set='basic', seed=0):
-        dep_model_dir = path.join(model_dir, 'deps')
-        pos_model_dir = path.join(model_dir, 'pos')
-        if path.exists(dep_model_dir):
-            shutil.rmtree(dep_model_dir)
-        if path.exists(pos_model_dir):
-            shutil.rmtree(pos_model_dir)
-        os.mkdir(dep_model_dir)
-        os.mkdir(pos_model_dir)
-
-        Config.write(dep_model_dir, 'config', features=feat_set, seed=seed,
-                     labels=labels)
-
-    @classmethod
-    def from_dir(cls, tag_map, model_dir):
-        vocab = Vocab(tag_map=tag_map, get_lex_attr=Language.default_lex_attrs())
-        vocab.get_lex_attr[spacy.attrs.LANG] = lambda _: 0
-        tokenizer = Tokenizer(vocab, {}, None, None, None)
-        tagger = Tagger.blank(vocab, TAGGER_TEMPLATES)
-
-        cfg = Config.read(path.join(model_dir, 'deps'), 'config')
-        parser = Parser.from_dir(path.join(model_dir, 'deps'), vocab.strings, ArcEager)
-        return cls(vocab, tokenizer, tagger, parser)
-
-    def __init__(self, vocab, tokenizer, tagger, parser):
-        self.vocab = vocab
-        self.tokenizer = tokenizer
-        self.tagger = tagger
-        self.parser = parser
-
-    def train(self, words, tags, heads, deps):
-        tokens = self.tokenizer.tokens_from_list(list(words))
-        self.tagger.train(tokens, tags)
-
-        tokens = self.tokenizer.tokens_from_list(list(words))
-        ids = range(len(words))
-        ner = ['O'] * len(words)
-        gold = GoldParse(tokens, ((ids, words, tags, heads, deps, ner)),
-                         make_projective=False)
-        self.tagger(tokens)
-        if gold.is_projective:
-            try:
-                self.parser.train(tokens, gold)
-            except:
-                for id_, word, head, dep in zip(ids, words, heads, deps):
-                    print(id_, word, head, dep)
-                raise
-
-    def __call__(self, words, tags=None):
-        tokens = self.tokenizer.tokens_from_list(list(words))
-        if tags is None:
-            self.tagger(tokens)
-        else:
-            self.tagger.tag_from_strings(tokens, tags)
-        self.parser(tokens)
-        return tokens
-
-    def end_training(self, data_dir):
-        self.parser.model.end_training()
-        self.parser.model.dump(path.join(data_dir, 'deps', 'model'))
-        self.tagger.model.end_training()
-        self.tagger.model.dump(path.join(data_dir, 'pos', 'model'))
-        strings_loc = path.join(data_dir, 'vocab', 'strings.json')
-        with io.open(strings_loc, 'w', encoding='utf8') as file_:
-            self.vocab.strings.dump(file_)
-        self.vocab.dump(path.join(data_dir, 'vocab', 'lexemes.bin'))
-
 
 def read_conllx(loc):
     with open(loc, 'r', 'utf8') as file_:
         text = file_.read()
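The hunk above deletes the bespoke TreebankParser wrapper; the rewritten script drives a Tagger and a spacy.pipeline.DependencyParser directly. The per-sentence training step that replaces TreebankParser.train follows the pattern sketched here (the `train_sentence` helper is hypothetical, introduced only to isolate the loop body added in main() below; it assumes the spaCy 1.x API this commit targets):

    from spacy.tokens import Doc
    from spacy.gold import GoldParse

    def train_sentence(vocab, tagger, parser, words, tags, heads, deps):
        # Hypothetical helper mirroring the loop body added in main() below.
        doc = Doc(vocab, words=words)
        gold = GoldParse(doc, tags=tags, heads=heads, deps=deps)
        tagger(doc)               # predict tags first; the parser conditions on them
        parser.update(doc, gold)  # one update step for the parser
        doc = Doc(vocab, words=words)
        tagger.update(doc, gold)  # tagger update on a fresh, untagged doc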
@@ -113,24 +34,30 @@ def read_conllx(loc):
                 lines.pop(0)
             tokens = []
             for line in lines:
-                id_, word, lemma, pos, tag, morph, head, dep, _1, _2 = line.split()
+                id_, word, lemma, tag, pos, morph, head, dep, _1, _2 = line.split()
                 if '-' in id_:
                     continue
-                id_ = int(id_) - 1
-                head = (int(head) - 1) if head != '0' else id_
-                dep = 'ROOT' if dep == 'root' else dep
-                tokens.append((id_, word, tag, head, dep, 'O'))
-            tuples = zip(*tokens)
-            yield (None, [(tuples, [])])
+                try:
+                    id_ = int(id_) - 1
+                    head = (int(head) - 1) if head != '0' else id_
+                    dep = 'ROOT' if dep == 'root' else dep
+                    tokens.append((id_, word, tag, head, dep, 'O'))
+                except:
+                    print(line)
+                    raise
+            tuples = [list(t) for t in zip(*tokens)]
+            yield (None, [[tuples, []]])
 
 
-def score_model(nlp, gold_docs, verbose=False):
+def score_model(vocab, tagger, parser, gold_docs, verbose=False):
     scorer = Scorer()
     for _, gold_doc in gold_docs:
-        for annot_tuples, _ in gold_doc:
-            tokens = nlp(list(annot_tuples[1]), tags=list(annot_tuples[2]))
-            gold = GoldParse(tokens, annot_tuples)
-            scorer.score(tokens, gold, verbose=verbose)
+        for (ids, words, tags, heads, deps, entities), _ in gold_doc:
+            doc = Doc(vocab, words=words)
+            tagger(doc)
+            parser(doc)
+            gold = GoldParse(doc, tags=tags, heads=heads, deps=deps)
+            scorer.score(doc, gold, verbose=verbose)
     return scorer
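score_model likewise stops going through a wrapper object: it takes the vocab and both pipeline components, rebuilds each sentence as a bare Doc, runs the tagger and then the parser, and scores against a GoldParse. For reference, one item yielded by the rewritten read_conllx has the following shape (toy values; note that a root token's head index points at the token itself, and the 'root' label is uppercased):

    ids   = [0, 1]
    words = ['Dogs', 'bark']
    tags  = ['NOUN', 'VERB']
    heads = [1, 1]             # 'Dogs' attaches to 'bark'; the root points at itself
    deps  = ['nsubj', 'ROOT']  # read_conllx rewrites 'root' as 'ROOT'
    ner   = ['O', 'O']
    gold_doc = (None, [[[ids, words, tags, heads, deps, ner], []]])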
@@ -138,22 +65,37 @@ def main(train_loc, dev_loc, model_dir, tag_map_loc):
     with open(tag_map_loc) as file_:
         tag_map = json.loads(file_.read())
     train_sents = list(read_conllx(train_loc))
-    labels = ArcEager.get_labels(train_sents)
-    templates = get_templates('basic')
+    train_sents = PseudoProjectivity.preprocess_training_data(train_sents)
 
-    TreebankParser.setup_model_dir(model_dir, labels, templates)
+    actions = ArcEager.get_actions(gold_parses=train_sents)
+    features = get_templates('basic')
 
-    nlp = TreebankParser.from_dir(tag_map, model_dir)
+    vocab = Vocab(lex_attr_getters=Language.Defaults.lex_attr_getters, tag_map=tag_map)
+    # Populate vocab
+    for _, doc_sents in train_sents:
+        for (ids, words, tags, heads, deps, ner), _ in doc_sents:
+            for word in words:
+                _ = vocab[word]
+            for tag in tags:
+                assert tag in tag_map, repr(tag)
+            print(tags)
+    tagger = Tagger(vocab, tag_map=tag_map)
+    parser = DependencyParser(vocab, actions=actions, features=features)
 
     for itn in range(15):
         for _, doc_sents in train_sents:
             for (ids, words, tags, heads, deps, ner), _ in doc_sents:
-                nlp.train(words, tags, heads, deps)
+                doc = Doc(vocab, words=words)
+                gold = GoldParse(doc, tags=tags, heads=heads, deps=deps)
+                tagger(doc)
+                parser.update(doc, gold)
+                doc = Doc(vocab, words=words)
+                tagger.update(doc, gold)
         random.shuffle(train_sents)
-        scorer = score_model(nlp, read_conllx(dev_loc))
+        scorer = score_model(vocab, tagger, parser, read_conllx(dev_loc))
         print('%d:\t%.3f\t%.3f' % (itn, scorer.uas, scorer.tags_acc))
+    nlp = Language(vocab=vocab, tagger=tagger, parser=parser)
     nlp.end_training(model_dir)
-    scorer = score_model(nlp, read_conllx(dev_loc))
+    scorer = score_model(vocab, tagger, parser, read_conllx(dev_loc))
     print('%d:\t%.3f\t%.3f\t%.3f' % (itn, scorer.uas, scorer.las, scorer.tags_acc))
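For reference, main() takes its arguments in the order train_loc, dev_loc, model_dir, tag_map_loc, so an invocation would look something like `python train_ud.py train.conllu dev.conllu /tmp/ud_model tag_map.json` (the paths here are hypothetical, and the script's own location is not shown in this view). During training, each of the 15 epochs prints the iteration number, UAS, and tag accuracy on the dev set; the final line after end_training also reports LAS.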