From 6ea1601e9366e8cc1a135e6225b5746023510ba2 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Thu, 8 Oct 2015 12:00:11 +1100 Subject: [PATCH] * Add script to train models off the UD treebanks. Note that the UD data is restricted to research purposes only, and should only be used to train models for academic experiments. --- bin/parser/train_ud.py | 151 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 151 insertions(+) create mode 100644 bin/parser/train_ud.py diff --git a/bin/parser/train_ud.py b/bin/parser/train_ud.py new file mode 100644 index 000000000..cd938de1f --- /dev/null +++ b/bin/parser/train_ud.py @@ -0,0 +1,151 @@ +import plac +import json +from os import path +import shutil +import os +import random + +from spacy.syntax.util import Config +from spacy.gold import GoldParse +from spacy.tokenizer import Tokenizer +from spacy.vocab import Vocab +from spacy.tagger import Tagger +from spacy.syntax.parser import Parser +from spacy.syntax.arc_eager import ArcEager +from spacy.syntax.parser import get_templates +from spacy.scorer import Scorer + +from spacy.language import Language + +from spacy.tagger import W_orth + +TAGGER_TEMPLATES = ( + (W_orth,), +) + +try: + from codecs import open +except ImportError: + pass + + +class TreebankParser(object): + @staticmethod + def setup_model_dir(model_dir, labels, templates, feat_set='basic', seed=0): + dep_model_dir = path.join(model_dir, 'deps') + pos_model_dir = path.join(model_dir, 'pos') + if path.exists(dep_model_dir): + shutil.rmtree(dep_model_dir) + if path.exists(pos_model_dir): + shutil.rmtree(pos_model_dir) + os.mkdir(dep_model_dir) + os.mkdir(pos_model_dir) + + Config.write(dep_model_dir, 'config', features=feat_set, seed=seed, + labels=labels) + + @classmethod + def from_dir(cls, tag_map, model_dir): + vocab = Vocab(tag_map=tag_map, get_lex_attr=Language.default_lex_attrs()) + tokenizer = Tokenizer(vocab, {}, None, None, None) + tagger = Tagger.blank(vocab, TAGGER_TEMPLATES) + + cfg = Config.read(path.join(model_dir, 'deps'), 'config') + parser = Parser.from_dir(path.join(model_dir, 'deps'), vocab.strings, ArcEager) + return cls(vocab, tokenizer, tagger, parser) + + def __init__(self, vocab, tokenizer, tagger, parser): + self.vocab = vocab + self.tokenizer = tokenizer + self.tagger = tagger + self.parser = parser + + def train(self, words, tags, heads, deps): + tokens = self.tokenizer.tokens_from_list(list(words)) + self.tagger.train(tokens, tags) + + tokens = self.tokenizer.tokens_from_list(list(words)) + ids = range(len(words)) + ner = ['O'] * len(words) + gold = GoldParse(tokens, ((ids, words, tags, heads, deps, ner)), + make_projective=False) + self.tagger(tokens) + if gold.is_projective: + try: + self.parser.train(tokens, gold) + except: + for id_, word, head, dep in zip(ids, words, heads, deps): + print(id_, word, head, dep) + raise + + def __call__(self, words, tags=None): + tokens = self.tokenizer.tokens_from_list(list(words)) + if tags is None: + self.tagger(tokens) + else: + self.tagger.tag_from_strings(tokens, tags) + self.parser(tokens) + return tokens + + def end_training(self, data_dir): + self.parser.model.end_training(path.join(data_dir, 'deps', 'model')) + self.tagger.model.end_training(path.join(data_dir, 'pos', 'model')) + self.vocab.strings.dump(path.join(data_dir, 'vocab', 'strings.txt')) + + +def read_conllx(loc): + with open(loc, 'r', 'utf8') as file_: + text = file_.read() + for sent in text.strip().split('\n\n'): + lines = sent.strip().split('\n') + if lines: + if lines[0].startswith('#'): + lines.pop(0) + tokens = [] + for line in lines: + id_, word, lemma, pos, tag, morph, head, dep, _1, _2 = line.split() + if '-' in id_: + continue + id_ = int(id_) - 1 + head = (int(head) - 1) if head != '0' else id_ + dep = 'ROOT' if dep == 'root' else dep + tokens.append((id_, word, tag, head, dep, 'O')) + tuples = zip(*tokens) + yield (None, [(tuples, [])]) + + +def score_model(nlp, gold_docs, verbose=False): + scorer = Scorer() + for _, gold_doc in gold_docs: + for annot_tuples, _ in gold_doc: + tokens = nlp(list(annot_tuples[1]), tags=list(annot_tuples[2])) + gold = GoldParse(tokens, annot_tuples) + scorer.score(tokens, gold, verbose=verbose) + return scorer + + +def main(train_loc, dev_loc, model_dir, tag_map_loc): + with open(tag_map_loc) as file_: + tag_map = json.loads(file_.read()) + train_sents = list(read_conllx(train_loc)) + labels = ArcEager.get_labels(train_sents) + templates = get_templates('basic') + + TreebankParser.setup_model_dir(model_dir, labels, templates) + + nlp = TreebankParser.from_dir(tag_map, model_dir) + + for itn in range(15): + for _, doc_sents in train_sents: + for (ids, words, tags, heads, deps, ner), _ in doc_sents: + nlp.train(words, tags, heads, deps) + random.shuffle(train_sents) + scorer = score_model(nlp, read_conllx(dev_loc)) + print('%d:\t%.3f\t%.3f' % (itn, scorer.uas, scorer.tags_acc)) + nlp.end_training(model_dir) + scorer = score_model(nlp, read_conllx(dev_loc)) + print('%d:\t%.3f\t%.3f\t%.3f' % (itn, scorer.uas, scorer.las, scorer.tags_acc)) + + +if __name__ == '__main__': + plac.call(main)