mirror of https://github.com/explosion/spaCy.git

Fix train_ud script, which trains models from the Universal Dependencies format.

parent 6dd3b94fa6
commit da5f0cce36
@@ -6,103 +6,24 @@ import os
 import random
 import io
 
-from spacy.syntax.util import Config
+from spacy.tokens import Doc
+from spacy.syntax.nonproj import PseudoProjectivity
+from spacy.language import Language
 from spacy.gold import GoldParse
-from spacy.tokenizer import Tokenizer
 from spacy.vocab import Vocab
 from spacy.tagger import Tagger
-from spacy.syntax.parser import Parser
-from spacy.syntax.arc_eager import ArcEager
+from spacy.pipeline import DependencyParser
 from spacy.syntax.parser import get_templates
+from spacy.syntax.arc_eager import ArcEager
 from spacy.scorer import Scorer
 import spacy.attrs
 
-from spacy.language import Language
-
-from spacy.tagger import W_orth
-
-TAGGER_TEMPLATES = (
-    (W_orth,),
-)
-
 try:
     from codecs import open
 except ImportError:
     pass
 
 
-class TreebankParser(object):
-    @staticmethod
-    def setup_model_dir(model_dir, labels, templates, feat_set='basic', seed=0):
-        dep_model_dir = path.join(model_dir, 'deps')
-        pos_model_dir = path.join(model_dir, 'pos')
-        if path.exists(dep_model_dir):
-            shutil.rmtree(dep_model_dir)
-        if path.exists(pos_model_dir):
-            shutil.rmtree(pos_model_dir)
-        os.mkdir(dep_model_dir)
-        os.mkdir(pos_model_dir)
-
-        Config.write(dep_model_dir, 'config', features=feat_set, seed=seed,
-                     labels=labels)
-
-    @classmethod
-    def from_dir(cls, tag_map, model_dir):
-        vocab = Vocab(tag_map=tag_map, get_lex_attr=Language.default_lex_attrs())
-        vocab.get_lex_attr[spacy.attrs.LANG] = lambda _: 0
-        tokenizer = Tokenizer(vocab, {}, None, None, None)
-        tagger = Tagger.blank(vocab, TAGGER_TEMPLATES)
-
-        cfg = Config.read(path.join(model_dir, 'deps'), 'config')
-        parser = Parser.from_dir(path.join(model_dir, 'deps'), vocab.strings, ArcEager)
-        return cls(vocab, tokenizer, tagger, parser)
-
-    def __init__(self, vocab, tokenizer, tagger, parser):
-        self.vocab = vocab
-        self.tokenizer = tokenizer
-        self.tagger = tagger
-        self.parser = parser
-
-    def train(self, words, tags, heads, deps):
-        tokens = self.tokenizer.tokens_from_list(list(words))
-        self.tagger.train(tokens, tags)
-
-        tokens = self.tokenizer.tokens_from_list(list(words))
-        ids = range(len(words))
-        ner = ['O'] * len(words)
-        gold = GoldParse(tokens, ((ids, words, tags, heads, deps, ner)),
-                         make_projective=False)
-        self.tagger(tokens)
-        if gold.is_projective:
-            try:
-                self.parser.train(tokens, gold)
-            except:
-                for id_, word, head, dep in zip(ids, words, heads, deps):
-                    print(id_, word, head, dep)
-                raise
-
-    def __call__(self, words, tags=None):
-        tokens = self.tokenizer.tokens_from_list(list(words))
-        if tags is None:
-            self.tagger(tokens)
-        else:
-            self.tagger.tag_from_strings(tokens, tags)
-        self.parser(tokens)
-        return tokens
-
-    def end_training(self, data_dir):
-        self.parser.model.end_training()
-        self.parser.model.dump(path.join(data_dir, 'deps', 'model'))
-        self.tagger.model.end_training()
-        self.tagger.model.dump(path.join(data_dir, 'pos', 'model'))
-        strings_loc = path.join(data_dir, 'vocab', 'strings.json')
-        with io.open(strings_loc, 'w', encoding='utf8') as file_:
-            self.vocab.strings.dump(file_)
-        self.vocab.dump(path.join(data_dir, 'vocab', 'lexemes.bin'))
-
-
 def read_conllx(loc):
     with open(loc, 'r', 'utf8') as file_:
         text = file_.read()
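The import block tells the story of the fix: the hand-rolled TreebankParser class deleted above is replaced by spaCy's own Tagger and DependencyParser components, plus PseudoProjectivity preprocessing for non-projective trees. As a minimal sketch of the new construction pattern, here is the wiring the rewritten main() performs, using only calls that appear in this diff (0.x-era spaCy API; the tag_map literal is a placeholder):

    # Sketch only: component setup as used in the rewritten main() below.
    # All names come from this diff; the tag_map value is a placeholder.
    from spacy.language import Language
    from spacy.vocab import Vocab
    from spacy.tagger import Tagger
    from spacy.pipeline import DependencyParser
    from spacy.syntax.arc_eager import ArcEager
    from spacy.syntax.parser import get_templates

    tag_map = {'NOUN': {'pos': 'NOUN'}}  # normally loaded from a JSON file
    vocab = Vocab(lex_attr_getters=Language.Defaults.lex_attr_getters,
                  tag_map=tag_map)
    tagger = Tagger(vocab, tag_map=tag_map)
    actions = ArcEager.get_actions(gold_parses=[])  # real script passes train_sents
    parser = DependencyParser(vocab, actions=actions,
                              features=get_templates('basic'))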
@@ -113,24 +34,30 @@ def read_conllx(loc):
         lines.pop(0)
         tokens = []
         for line in lines:
-            id_, word, lemma, pos, tag, morph, head, dep, _1, _2 = line.split()
+            id_, word, lemma, tag, pos, morph, head, dep, _1, _2 = line.split()
             if '-' in id_:
                 continue
-            id_ = int(id_) - 1
-            head = (int(head) - 1) if head != '0' else id_
-            dep = 'ROOT' if dep == 'root' else dep
-            tokens.append((id_, word, tag, head, dep, 'O'))
-        tuples = zip(*tokens)
-        yield (None, [(tuples, [])])
+            try:
+                id_ = int(id_) - 1
+                head = (int(head) - 1) if head != '0' else id_
+                dep = 'ROOT' if dep == 'root' else dep
+                tokens.append((id_, word, tag, head, dep, 'O'))
+            except:
+                print(line)
+                raise
+        tuples = [list(t) for t in zip(*tokens)]
+        yield (None, [[tuples, []]])
 
 
-def score_model(nlp, gold_docs, verbose=False):
+def score_model(vocab, tagger, parser, gold_docs, verbose=False):
     scorer = Scorer()
     for _, gold_doc in gold_docs:
-        for annot_tuples, _ in gold_doc:
-            tokens = nlp(list(annot_tuples[1]), tags=list(annot_tuples[2]))
-            gold = GoldParse(tokens, annot_tuples)
-            scorer.score(tokens, gold, verbose=verbose)
+        for (ids, words, tags, heads, deps, entities), _ in gold_doc:
+            doc = Doc(vocab, words=words)
+            tagger(doc)
+            parser(doc)
+            gold = GoldParse(doc, tags=tags, heads=heads, deps=deps)
+            scorer.score(doc, gold, verbose=verbose)
     return scorer
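Two things in the hunk above are worth spelling out. First, each data line in a CoNLL-X/CoNLL-U file carries ten tab-separated columns (ID, FORM, LEMMA, UPOS, XPOS, FEATS, HEAD, DEPREL, DEPS, MISC), and the one-word change on the split line is the core of the fix: the universal tag in column 4, not the treebank-specific tag in column 5, is what gets stored in the tuples and must match the tag map. A small illustration with an invented line:

    # Invented CoNLL-U style line showing the unpacking order after the fix.
    line = '2\tdogs\tdog\tNOUN\tNNS\tNumber=Plur\t3\tnsubj\t_\t_'
    id_, word, lemma, tag, pos, morph, head, dep, _1, _2 = line.split()
    assert tag == 'NOUN'  # universal POS, column 4: kept and used downstream
    assert pos == 'NNS'   # treebank tag, column 5: no longer used

Second, the rewritten score_model builds each Doc straight from the gold tokens, runs the tagger and parser on it, and scores against a GoldParse, so the reported numbers measure tagging and parsing accuracy on gold segmentation. A sketch of the input shape it expects, matching what read_conllx yields (the sentence is invented; vocab, tagger, and parser would be the trained objects from main()):

    # gold_docs entries have the shape (None, [[tuples, []]]), where
    # tuples is [ids, words, tags, heads, deps, ner].
    ids, words = [0, 1], ['Dogs', 'bark']
    tags, heads = ['NOUN', 'VERB'], [1, 1]
    deps, ner = ['nsubj', 'ROOT'], ['O', 'O']
    gold_docs = [(None, [[[ids, words, tags, heads, deps, ner], []]])]
    # scorer = score_model(vocab, tagger, parser, gold_docs)
    # print(scorer.uas, scorer.las, scorer.tags_acc)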
@@ -138,22 +65,37 @@ def main(train_loc, dev_loc, model_dir, tag_map_loc):
     with open(tag_map_loc) as file_:
         tag_map = json.loads(file_.read())
     train_sents = list(read_conllx(train_loc))
-    labels = ArcEager.get_labels(train_sents)
-    templates = get_templates('basic')
-
-    TreebankParser.setup_model_dir(model_dir, labels, templates)
+    train_sents = PseudoProjectivity.preprocess_training_data(train_sents)
+    actions = ArcEager.get_actions(gold_parses=train_sents)
+    features = get_templates('basic')
 
-    nlp = TreebankParser.from_dir(tag_map, model_dir)
+    vocab = Vocab(lex_attr_getters=Language.Defaults.lex_attr_getters, tag_map=tag_map)
+    # Populate vocab
+    for _, doc_sents in train_sents:
+        for (ids, words, tags, heads, deps, ner), _ in doc_sents:
+            for word in words:
+                _ = vocab[word]
+            for tag in tags:
+                assert tag in tag_map, repr(tag)
+            print(tags)
+    tagger = Tagger(vocab, tag_map=tag_map)
+    parser = DependencyParser(vocab, actions=actions, features=features)
 
     for itn in range(15):
         for _, doc_sents in train_sents:
             for (ids, words, tags, heads, deps, ner), _ in doc_sents:
-                nlp.train(words, tags, heads, deps)
+                doc = Doc(vocab, words=words)
+                gold = GoldParse(doc, tags=tags, heads=heads, deps=deps)
+                tagger(doc)
+                parser.update(doc, gold)
+                doc = Doc(vocab, words=words)
+                tagger.update(doc, gold)
         random.shuffle(train_sents)
-        scorer = score_model(nlp, read_conllx(dev_loc))
+        scorer = score_model(vocab, tagger, parser, read_conllx(dev_loc))
         print('%d:\t%.3f\t%.3f' % (itn, scorer.uas, scorer.tags_acc))
+    nlp = Language(vocab=vocab, tagger=tagger, parser=parser)
     nlp.end_training(model_dir)
-    scorer = score_model(nlp, read_conllx(dev_loc))
+    scorer = score_model(vocab, tagger, parser, read_conllx(dev_loc))
     print('%d:\t%.3f\t%.3f\t%.3f' % (itn, scorer.uas, scorer.las, scorer.tags_acc))
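One detail of the new training loop is easy to miss: the parser is updated on a Doc that has already been tagged, but the tagger is then updated on a fresh, untagged Doc, so each component trains on the kind of input it will see at run time. Condensed, the inner step reads:

    # One training step, condensed from the loop above (0.x-era API).
    doc = Doc(vocab, words=words)
    gold = GoldParse(doc, tags=tags, heads=heads, deps=deps)
    tagger(doc)               # predict tags; the parser uses them as features
    parser.update(doc, gold)  # parser trains against gold heads and labels
    doc = Doc(vocab, words=words)
    tagger.update(doc, gold)  # tagger trains from a fresh, untagged doc

At the end, the trained components are bundled into a Language instance so the existing nlp.end_training(model_dir) serialization path can write out the model directory.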