diff --git a/examples/training/train_parser.py b/examples/training/train_parser.py new file mode 100644 index 000000000..5850a91a5 --- /dev/null +++ b/examples/training/train_parser.py @@ -0,0 +1,81 @@ +from __future__ import unicode_literals, print_function +import json +import pathlib +import random + +import spacy +from spacy.pipeline import DependencyParser +from spacy.gold import GoldParse +from spacy.tokens import Doc + + +def train_parser(nlp, train_data, left_labels, right_labels): + parser = DependencyParser.blank( + nlp.vocab, + left_labels=left_labels, + right_labels=right_labels, + features=nlp.defaults.parser_features) + for itn in range(1000): + random.shuffle(train_data) + loss = 0 + for words, heads, deps in train_data: + doc = nlp.make_doc(words) + gold = GoldParse(doc, heads=heads, deps=deps) + loss += parser.update(doc, gold) + parser.model.end_training() + return parser + + +def main(model_dir=None): + if model_dir is not None: + model_dir = pathlb.Path(model_dir) + if not model_dir.exists(): + model_dir.mkdir() + assert model_dir.isdir() + + nlp = spacy.load('en', tagger=False, parser=False, entity=False, vectors=False) + nlp.make_doc = lambda words: Doc(nlp.vocab, zip(words, [True]*len(words))) + + train_data = [ + ( + ['They', 'trade', 'mortgage', '-', 'backed', 'securities', '.'], + [1, 1, 4, 4, 5, 1, 1], + ['nsubj', 'ROOT', 'compound', 'punct', 'nmod', 'dobj', 'punct'] + ), + ( + ['I', 'like', 'London', 'and', 'Berlin', '.'], + [1, 1, 1, 2, 2, 1], + ['nsubj', 'ROOT', 'dobj', 'cc', 'conj', 'punct'] + ) + ] + left_labels = set() + right_labels = set() + for _, heads, deps in train_data: + for i, (head, dep) in enumerate(zip(heads, deps)): + if i < head: + left_labels.add(dep) + elif i > head: + right_labels.add(dep) + parser = train_parser(nlp, train_data, sorted(left_labels), sorted(right_labels)) + + doc = nlp.make_doc(['I', 'like', 'securities', '.']) + with parser.step_through(doc) as state: + while not state.is_final: + action = state.predict() + state.transition(action) + #parser(doc) + for word in doc: + print(word.text, word.dep_, word.head.text) + + if model_dir is not None: + with (model_dir / 'config.json').open('wb') as file_: + json.dump(parser.cfg, file_) + parser.model.dump(str(model_dir / 'model')) + + +if __name__ == '__main__': + main() + # I nsubj like + # like ROOT like + # securities dobj like + # . cc securities