diff --git a/spacy/cli/conll17_ud_eval.py b/spacy/cli/conll17_ud_eval.py new file mode 100644 index 000000000..43fbcf3fa --- /dev/null +++ b/spacy/cli/conll17_ud_eval.py @@ -0,0 +1,570 @@ +#!/usr/bin/env python + +# CoNLL 2017 UD Parsing evaluation script. +# +# Compatible with Python 2.7 and 3.2+, can be used either as a module +# or a standalone executable. +# +# Copyright 2017 Institute of Formal and Applied Linguistics (UFAL), +# Faculty of Mathematics and Physics, Charles University, Czech Republic. +# +# Changelog: +# - [02 Jan 2017] Version 0.9: Initial release +# - [25 Jan 2017] Version 0.9.1: Fix bug in LCS alignment computation +# - [10 Mar 2017] Version 1.0: Add documentation and test +# Compare HEADs correctly using aligned words +# Allow evaluation with erroneous spaces in forms +# Compare forms in LCS case insensitively +# Detect cycles and multiple root nodes +# Compute AlignedAccuracy + +# Command line usage +# ------------------ +# conll17_ud_eval.py [-v] [-w weights_file] gold_conllu_file system_conllu_file +# +# - if no -v is given, only the CoNLL17 UD Shared Task evaluation LAS metric +# is printed +# - if -v is given, several metrics are printed (as precision, recall, F1 score, +# and in case the metric is computed on aligned words also accuracy on these): +# - Tokens: how well do the gold tokens match system tokens +# - Sentences: how well do the gold sentences match system sentences +# - Words: how well can the gold words be aligned to system words +# - UPOS: using aligned words, how well does UPOS match +# - XPOS: using aligned words, how well does XPOS match +# - Feats: using aligned words, how well does FEATS match +# - AllTags: using aligned words, how well does UPOS+XPOS+FEATS match +# - Lemmas: using aligned words, how well does LEMMA match +# - UAS: using aligned words, how well does HEAD match +# - LAS: using aligned words, how well does HEAD+DEPREL(ignoring subtypes) match +# - if weights_file is given (with lines containing 
deprel-weight pairs), +# one more metric is shown: +# - WeightedLAS: as LAS, but each deprel (ignoring subtypes) has different weight + +# API usage +# --------- +# - load_conllu(file) +# - loads CoNLL-U file from given file object to an internal representation +# - the file object should return str on both Python 2 and Python 3 +# - raises UDError exception if the given file cannot be loaded +# - evaluate(gold_ud, system_ud) +# - evaluate the given gold and system CoNLL-U files (loaded with load_conllu) +# - raises UDError if the concatenated tokens of gold and system file do not match +# - returns a dictionary with the metrics described above, each metric having +# three fields: precision, recall and f1 + +# Description of token matching +# ----------------------------- +# In order to match tokens of gold file and system file, we consider the text +# resulting from concatenation of gold tokens and text resulting from +# concatenation of system tokens. These texts should match -- if they do not, +# the evaluation fails. +# +# If the texts do match, every token is represented as a range in this original +# text, and tokens are equal only if their range is the same. + +# Description of word matching +# ---------------------------- +# When matching words of gold file and system file, we first match the tokens. +# The words which are also tokens are matched as tokens, but words in multi-word +# tokens have to be handled differently. +# +# To handle multi-word tokens, we start by finding "multi-word spans". +# Multi-word span is a span in the original text such that +# - it contains at least one multi-word token +# - all multi-word tokens in the span (considering both gold and system ones) +# are completely inside the span (i.e., they do not "stick out") +# - the multi-word span is as small as possible +# +# For every multi-word span, we align the gold and system words completely +# inside this span using LCS on their FORMs. 
# The words not intersecting
# (even partially) any multi-word span are then aligned as tokens.


from __future__ import division
from __future__ import print_function

import argparse
import io
import sys
import unittest

# CoNLL-U column names
ID, FORM, LEMMA, UPOS, XPOS, FEATS, HEAD, DEPREL, DEPS, MISC = range(10)


# UD Error is used when raising exceptions in this module
class UDError(Exception):
    """Error raised by this module, e.g. by load_conllu for malformed input."""
    pass


# Load given CoNLL-U file into internal representation
def load_conllu(file):
    """Load a CoNLL-U file into the internal representation.

    Args:
        file: File-like object; only ``readline()`` is used and it must
            return text lines (an empty string signals end of file).

    Returns:
        A ``UDRepresentation`` whose ``characters``, ``tokens``, ``words``
        and ``sentences`` attributes describe the whole file.

    Raises:
        UDError: If a line does not have 10 tab-separated columns, an ID or
            HEAD cannot be parsed or is out of range, a FORM is empty, a
            sentence contains a cycle or does not have exactly one root, or
            the file does not end with an empty line.
    """
    # Internal representation classes
    class UDRepresentation:
        def __init__(self):
            # Characters of all the tokens in the whole file.
            # Whitespace between tokens is not included.
            self.characters = []
            # List of UDSpan instances with start&end indices into `characters`.
            self.tokens = []
            # List of UDWord instances.
            self.words = []
            # List of UDSpan instances with start&end indices into `characters`.
            self.sentences = []

    class UDSpan:
        def __init__(self, start, end, characters):
            self.start = start
            # Note that self.end marks the first position **after the end** of span,
            # so we can use characters[start:end] or range(start, end).
            self.end = end
            self.characters = characters

        @property
        def text(self):
            return ''.join(self.characters[self.start:self.end])

        def __str__(self):
            return self.text

        def __repr__(self):
            return self.text

    class UDWord:
        def __init__(self, span, columns, is_multiword):
            # Span of this word (or MWT, see below) within ud_representation.characters.
            self.span = span
            # 10 columns of the CoNLL-U file: ID, FORM, LEMMA,...
            self.columns = columns
            # is_multiword==True means that this word is part of a multi-word token.
            # In that case, self.span marks the span of the whole multi-word token.
            self.is_multiword = is_multiword
            # Reference to the UDWord instance representing the HEAD (or None if root).
            self.parent = None
            # Let's ignore language-specific deprel subtypes.
            self.columns[DEPREL] = columns[DEPREL].split(':')[0]

    # Marker stored in `parent` while the HEAD chain of a word is being followed;
    # reaching a word that carries it means the chain loops back on itself.
    in_progress = object()

    def parse_head(word, sentence_length):
        # HEAD of words inside a multi-word token is not checked when the line
        # is read, so it is validated here for every word.
        try:
            head = int(word.columns[HEAD])
        except ValueError:
            raise UDError("Cannot parse HEAD '{}'".format(word.columns[HEAD]))
        if head < 0:
            raise UDError("HEAD cannot be negative")
        if head > sentence_length:
            raise UDError("HEAD '{}' points outside of the sentence".format(word.columns[HEAD]))
        return head

    def link_parents(sentence_words):
        # Add parent UDWord links and check there are no cycles. The chain of
        # HEADs is walked iteratively so that long chains cannot exhaust the
        # interpreter recursion limit.
        for word in sentence_words:
            chain = []
            current = word
            while current.parent is None:
                head = parse_head(current, len(sentence_words))
                if not head:
                    # Root word: its parent stays None.
                    break
                parent = sentence_words[head - 1]
                current.parent = in_progress
                chain.append((current, parent))
                current = parent
            if current.parent is in_progress:
                raise UDError("There is a cycle in a sentence")
            for child, parent in chain:
                child.parent = parent

    ud = UDRepresentation()

    # Load the CoNLL-U file
    index, sentence_start = 0, None
    linenum = 0
    while True:
        line = file.readline()
        linenum += 1
        if not line:
            break
        line = line.rstrip("\r\n")

        # Handle sentence start boundaries
        if sentence_start is None:
            # Skip comments
            if line.startswith("#"):
                continue
            # Start a new sentence; the end index is filled in at the sentence end.
            ud.sentences.append(UDSpan(index, 0, ud.characters))
            sentence_start = len(ud.words)
        if not line:
            sentence_words = ud.words[sentence_start:]
            link_parents(sentence_words)

            # Check there is a single root node
            if sum(1 for word in sentence_words if word.parent is None) != 1:
                raise UDError("There are multiple roots in a sentence")

            # End the sentence
            ud.sentences[-1].end = index
            sentence_start = None
            continue

        # Read next token/word
        columns = line.split("\t")
        if len(columns) != 10:
            raise UDError("The CoNLL-U line {} does not contain 10 tab-separated columns: '{}'".format(linenum, line))

        # Skip empty nodes
        if "." in columns[ID]:
            continue

        # Delete spaces from FORM so gold.characters == system.characters
        # even if one of them tokenizes the space.
        columns[FORM] = columns[FORM].replace(" ", "")
        if not columns[FORM]:
            raise UDError("There is an empty FORM in the CoNLL-U file -- line %d" % linenum)

        # Save token
        ud.characters.extend(columns[FORM])
        ud.tokens.append(UDSpan(index, index + len(columns[FORM]), ud.characters))
        index += len(columns[FORM])

        # Handle multi-word tokens to save word(s)
        if "-" in columns[ID]:
            try:
                start, end = map(int, columns[ID].split("-"))
            except ValueError:
                raise UDError("Cannot parse multi-word token ID '{}'".format(columns[ID]))

            for _ in range(start, end + 1):
                word_line = file.readline().rstrip("\r\n")
                # Keep the line counter in step with every line consumed.
                linenum += 1
                word_columns = word_line.split("\t")
                if len(word_columns) != 10:
                    raise UDError("The CoNLL-U line {} does not contain 10 tab-separated columns: '{}'".format(linenum, word_line))
                ud.words.append(UDWord(ud.tokens[-1], word_columns, is_multiword=True))
        # Basic tokens/words
        else:
            try:
                word_id = int(columns[ID])
            except ValueError:
                raise UDError("Cannot parse word ID '{}'".format(columns[ID]))
            if word_id != len(ud.words) - sentence_start + 1:
                raise UDError("Incorrect word ID '{}' for word '{}', expected '{}'".format(columns[ID], columns[FORM], len(ud.words) - sentence_start + 1))

            try:
                head_id = int(columns[HEAD])
            except ValueError:
                raise UDError("Cannot parse HEAD '{}'".format(columns[HEAD]))
            if head_id < 0:
                raise UDError("HEAD cannot be negative")

            ud.words.append(UDWord(ud.tokens[-1], columns, is_multiword=False))

    if sentence_start is not None:
        raise UDError("The CoNLL-U file does not end with empty line")

    return ud

# Evaluate the gold and system treebanks (loaded using load_conllu).
def evaluate(gold_ud, system_ud, deprel_weights=None):
    """Score a system treebank against a gold treebank.

    Args:
        gold_ud: Gold data; its ``characters``, ``tokens``, ``sentences``
            and ``words`` attributes are read.
        system_ud: System data with the same attributes.
        deprel_weights: Optional dict mapping DEPREL to a weight. When it is
            not None a "WeightedLAS" entry is added; relations missing from
            the dict get weight 1.0.

    Returns:
        Dict mapping metric name to a ``Score`` with ``precision``,
        ``recall``, ``f1`` and ``aligned_accuracy`` attributes.

    Raises:
        UDError: If the concatenated token characters of the two inputs differ.
    """
    class Score:
        def __init__(self, gold_total, system_total, correct, aligned_total=None):
            self.precision = correct / system_total if system_total else 0.0
            self.recall = correct / gold_total if gold_total else 0.0
            self.f1 = 2 * correct / (system_total + gold_total) if system_total + gold_total else 0.0
            # Stays None (or 0) when there is no aligned total to divide by.
            self.aligned_accuracy = correct / aligned_total if aligned_total else aligned_total

    class AlignmentWord:
        def __init__(self, gold_word, system_word):
            self.gold_word = gold_word
            self.system_word = system_word
            self.gold_parent = None
            self.system_parent_gold_aligned = None

    class Alignment:
        def __init__(self, gold_words, system_words):
            self.gold_words = gold_words
            self.system_words = system_words
            self.matched_words = []
            self.matched_words_map = {}

        def append_aligned_words(self, gold_word, system_word):
            self.matched_words.append(AlignmentWord(gold_word, system_word))
            self.matched_words_map[system_word] = gold_word

        def fill_parents(self):
            # We represent root parents in both gold and system data by '0'.
            # For gold data, we represent non-root parent by corresponding gold word.
            # For system data, we represent non-root parent by either gold word aligned
            # to parent system nodes, or by None if no gold words is aligned to the parent.
            for words in self.matched_words:
                words.gold_parent = words.gold_word.parent if words.gold_word.parent is not None else 0
                words.system_parent_gold_aligned = self.matched_words_map.get(words.system_word.parent, None) \
                    if words.system_word.parent is not None else 0

    def lower(text):
        # On Python 2 byte strings are decoded first so lowercasing is done on text.
        if sys.version_info < (3, 0) and isinstance(text, str):
            return text.decode("utf-8").lower()
        return text.lower()

    def spans_score(gold_spans, system_spans):
        # Walk both span lists in order of start offset; a span is correct
        # when both its start and its end coincide.
        correct, gi, si = 0, 0, 0
        while gi < len(gold_spans) and si < len(system_spans):
            if system_spans[si].start < gold_spans[gi].start:
                si += 1
            elif gold_spans[gi].start < system_spans[si].start:
                gi += 1
            else:
                correct += gold_spans[gi].end == system_spans[si].end
                si += 1
                gi += 1

        return Score(len(gold_spans), len(system_spans), correct)

    def alignment_score(alignment, key_fn, weight_fn=lambda w: 1):
        gold, system, aligned, correct = 0, 0, 0, 0

        for word in alignment.gold_words:
            gold += weight_fn(word)

        for word in alignment.system_words:
            system += weight_fn(word)

        for words in alignment.matched_words:
            aligned += weight_fn(words.gold_word)

        if key_fn is None:
            # Return score for whole aligned words
            return Score(gold, system, aligned)

        for words in alignment.matched_words:
            if key_fn(words.gold_word, words.gold_parent) == key_fn(words.system_word, words.system_parent_gold_aligned):
                correct += weight_fn(words.gold_word)

        return Score(gold, system, correct, aligned)

    def beyond_end(words, i, multiword_span_end):
        if i >= len(words):
            return True
        if words[i].is_multiword:
            return words[i].span.start >= multiword_span_end
        return words[i].span.end > multiword_span_end

    def extend_end(word, multiword_span_end):
        if word.is_multiword and word.span.end > multiword_span_end:
            return word.span.end
        return multiword_span_end

    def find_multiword_span(gold_words, system_words, gi, si):
        # We know gold_words[gi].is_multiword or system_words[si].is_multiword.
        # Find the start of the multiword span (gs, ss), so the multiword span is minimal.
        # Initialize multiword_span_end characters index.
        if gold_words[gi].is_multiword:
            multiword_span_end = gold_words[gi].span.end
            if not system_words[si].is_multiword and system_words[si].span.start < gold_words[gi].span.start:
                si += 1
        else:  # if system_words[si].is_multiword
            multiword_span_end = system_words[si].span.end
            if not gold_words[gi].is_multiword and gold_words[gi].span.start < system_words[si].span.start:
                gi += 1
        gs, ss = gi, si

        # Find the end of the multiword span
        # (so both gi and si are pointing to the word following the multiword span end).
        while not beyond_end(gold_words, gi, multiword_span_end) or \
                not beyond_end(system_words, si, multiword_span_end):
            if gi < len(gold_words) and (si >= len(system_words) or
                                         gold_words[gi].span.start <= system_words[si].span.start):
                multiword_span_end = extend_end(gold_words[gi], multiword_span_end)
                gi += 1
            else:
                multiword_span_end = extend_end(system_words[si], multiword_span_end)
                si += 1
        return gs, ss, gi, si

    def compute_lcs(gold_words, system_words, gi, si, gs, ss):
        # lcs[g][s] is the length of the longest common subsequence of FORMs
        # between gold words from gs+g and system words from ss+s onwards.
        lcs = [[0] * (si - ss) for i in range(gi - gs)]
        for g in reversed(range(gi - gs)):
            for s in reversed(range(si - ss)):
                if lower(gold_words[gs + g].columns[FORM]) == lower(system_words[ss + s].columns[FORM]):
                    lcs[g][s] = 1 + (lcs[g+1][s+1] if g+1 < gi-gs and s+1 < si-ss else 0)
                lcs[g][s] = max(lcs[g][s], lcs[g+1][s] if g+1 < gi-gs else 0)
                lcs[g][s] = max(lcs[g][s], lcs[g][s+1] if s+1 < si-ss else 0)
        return lcs

    def align_words(gold_words, system_words):
        alignment = Alignment(gold_words, system_words)

        gi, si = 0, 0
        while gi < len(gold_words) and si < len(system_words):
            if gold_words[gi].is_multiword or system_words[si].is_multiword:
                # A: Multi-word tokens => align via LCS within the whole "multiword span".
                gs, ss, gi, si = find_multiword_span(gold_words, system_words, gi, si)

                if si > ss and gi > gs:
                    lcs = compute_lcs(gold_words, system_words, gi, si, gs, ss)

                    # Store aligned words
                    s, g = 0, 0
                    while g < gi - gs and s < si - ss:
                        if lower(gold_words[gs + g].columns[FORM]) == lower(system_words[ss + s].columns[FORM]):
                            alignment.append_aligned_words(gold_words[gs+g], system_words[ss+s])
                            g += 1
                            s += 1
                        elif lcs[g][s] == (lcs[g+1][s] if g+1 < gi-gs else 0):
                            g += 1
                        else:
                            s += 1
            else:
                # B: No multi-word token => align according to spans.
                if (gold_words[gi].span.start, gold_words[gi].span.end) == (system_words[si].span.start, system_words[si].span.end):
                    alignment.append_aligned_words(gold_words[gi], system_words[si])
                    gi += 1
                    si += 1
                elif gold_words[gi].span.start <= system_words[si].span.start:
                    gi += 1
                else:
                    si += 1

        alignment.fill_parents()

        return alignment

    # Check that underlying character sequences do match
    if gold_ud.characters != system_ud.characters:
        # Stop at the end of the shorter sequence: when one sequence is a
        # prefix of the other there is no differing index inside it.
        shared = min(len(gold_ud.characters), len(system_ud.characters))
        index = 0
        while index < shared and gold_ud.characters[index] == system_ud.characters[index]:
            index += 1

        raise UDError(
            "The concatenation of tokens in gold file and in system file differ!\n" +
            "First 20 differing characters in gold file: '{}' and system file: '{}'".format(
                "".join(gold_ud.characters[index:index + 20]),
                "".join(system_ud.characters[index:index + 20])
            )
        )

    # Align words
    alignment = align_words(gold_ud.words, system_ud.words)

    # Compute the F1-scores
    result = {
        "Tokens": spans_score(gold_ud.tokens, system_ud.tokens),
        "Sentences": spans_score(gold_ud.sentences, system_ud.sentences),
        "Words": alignment_score(alignment, None),
        "UPOS": alignment_score(alignment, lambda w, parent: w.columns[UPOS]),
        "XPOS": alignment_score(alignment, lambda w, parent: w.columns[XPOS]),
        "Feats": alignment_score(alignment, lambda w, parent: w.columns[FEATS]),
        "AllTags": alignment_score(alignment, lambda w, parent: (w.columns[UPOS], w.columns[XPOS], w.columns[FEATS])),
        "Lemmas": alignment_score(alignment, lambda w, parent: w.columns[LEMMA]),
        "UAS": alignment_score(alignment, lambda w, parent: parent),
        "LAS": alignment_score(alignment, lambda w, parent: (parent, w.columns[DEPREL])),
    }

    # Add WeightedLAS if weights are given
    if deprel_weights is not None:
        def weighted_las(word):
            return deprel_weights.get(word.columns[DEPREL], 1.0)
        result["WeightedLAS"] = alignment_score(alignment, lambda w, parent: (parent, w.columns[DEPREL]), weighted_las)

    return result


def load_deprel_weights(weights_file):
    """Read deprel weights from an iterable of text lines.

    Lines starting with "#" and blank lines are ignored; every other line
    must hold exactly two whitespace-separated columns: relation and weight.

    Returns:
        Dict mapping relation to float weight, or None if ``weights_file``
        is None.

    Raises:
        ValueError: If a line does not have two columns or the weight is
            not a valid float.
    """
    if weights_file is None:
        return None

    deprel_weights = {}
    for line in weights_file:
        # Ignore comments and empty lines
        if line.startswith("#") or not line.strip():
            continue

        columns = line.rstrip("\r\n").split()
        if len(columns) != 2:
            raise ValueError("Expected two columns in the UD Relations weights file on line '{}'".format(line))

        deprel_weights[columns[0]] = float(columns[1])

    return deprel_weights


def load_conllu_file(path):
    """Open ``path`` (as UTF-8 on Python 3) and load it with load_conllu."""
    # The file is closed as soon as it is parsed instead of being left to the
    # garbage collector.
    with open(path, mode="r", **({"encoding": "utf-8"} if sys.version_info >= (3, 0) else {})) as _file:
        return load_conllu(_file)


def evaluate_wrapper(args):
    """Load the files named in ``args`` and evaluate them.

    Reads ``args.gold_file``, ``args.system_file`` and ``args.weights``.
    """
    # Load CoNLL-U files
    gold_ud = load_conllu_file(args.gold_file)
    system_ud = load_conllu_file(args.system_file)

    # Load weights if requested
    deprel_weights = load_deprel_weights(args.weights)

    return evaluate(gold_ud, system_ud, deprel_weights)


def main():
    """Command line entry point: parse arguments, evaluate, print the metrics."""
    # Parse arguments
    parser = argparse.ArgumentParser()
    parser.add_argument("gold_file", type=str,
                        help="Name of the CoNLL-U file with the gold data.")
    parser.add_argument("system_file", type=str,
                        help="Name of the CoNLL-U file with the predicted data.")
    parser.add_argument("--weights", "-w", type=argparse.FileType("r"), default=None,
                        metavar="deprel_weights_file",
                        help="Compute WeightedLAS using given weights for Universal Dependency Relations.")
    parser.add_argument("--verbose", "-v", default=0, action="count",
                        help="Print all metrics.")
    args = parser.parse_args()

    # Use verbose if weights are supplied
    if args.weights is not None and not args.verbose:
        args.verbose = 1

    # Evaluate
    try:
        evaluation = evaluate_wrapper(args)
    finally:
        # argparse opened the weights file for us; release it once it is read.
        if args.weights is not None:
            args.weights.close()

    # Print the evaluation
    if not args.verbose:
        print("LAS F1 Score: {:.2f}".format(100 * evaluation["LAS"].f1))
    else:
        metrics = ["Tokens", "Sentences", "Words", "UPOS", "XPOS", "Feats", "AllTags", "Lemmas", "UAS", "LAS"]
        if args.weights is not None:
            metrics.append("WeightedLAS")

        print("Metrics    | Precision |    Recall |  F1 Score | AligndAcc")
        print("-----------+-----------+-----------+-----------+-----------")
        for metric in metrics:
            print("{:11}|{:10.2f} |{:10.2f} |{:10.2f} |{}".format(
                metric,
                100 * evaluation[metric].precision,
                100 * evaluation[metric].recall,
                100 * evaluation[metric].f1,
                "{:10.2f}".format(100 * evaluation[metric].aligned_accuracy) if evaluation[metric].aligned_accuracy is not None else ""
            ))


if __name__ == "__main__":
    main()

# Tests, which can be executed with `python -m unittest conll17_ud_eval`.
+class TestAlignment(unittest.TestCase): + @staticmethod + def _load_words(words): + """Prepare fake CoNLL-U files with fake HEAD to prevent multiple roots errors.""" + lines, num_words = [], 0 + for w in words: + parts = w.split(" ") + if len(parts) == 1: + num_words += 1 + lines.append("{}\t{}\t_\t_\t_\t_\t{}\t_\t_\t_".format(num_words, parts[0], int(num_words>1))) + else: + lines.append("{}-{}\t{}\t_\t_\t_\t_\t_\t_\t_\t_".format(num_words + 1, num_words + len(parts) - 1, parts[0])) + for part in parts[1:]: + num_words += 1 + lines.append("{}\t{}\t_\t_\t_\t_\t{}\t_\t_\t_".format(num_words, part, int(num_words>1))) + return load_conllu((io.StringIO if sys.version_info >= (3, 0) else io.BytesIO)("\n".join(lines+["\n"]))) + + def _test_exception(self, gold, system): + self.assertRaises(UDError, evaluate, self._load_words(gold), self._load_words(system)) + + def _test_ok(self, gold, system, correct): + metrics = evaluate(self._load_words(gold), self._load_words(system)) + gold_words = sum((max(1, len(word.split(" ")) - 1) for word in gold)) + system_words = sum((max(1, len(word.split(" ")) - 1) for word in system)) + self.assertEqual((metrics["Words"].precision, metrics["Words"].recall, metrics["Words"].f1), + (correct / system_words, correct / gold_words, 2 * correct / (gold_words + system_words))) + + def test_exception(self): + self._test_exception(["a"], ["b"]) + + def test_equal(self): + self._test_ok(["a"], ["a"], 1) + self._test_ok(["a", "b", "c"], ["a", "b", "c"], 3) + + def test_equal_with_multiword(self): + self._test_ok(["abc a b c"], ["a", "b", "c"], 3) + self._test_ok(["a", "bc b c", "d"], ["a", "b", "c", "d"], 4) + self._test_ok(["abcd a b c d"], ["ab a b", "cd c d"], 4) + self._test_ok(["abc a b c", "de d e"], ["a", "bcd b c d", "e"], 5) + + def test_alignment(self): + self._test_ok(["abcd"], ["a", "b", "c", "d"], 0) + self._test_ok(["abc", "d"], ["a", "b", "c", "d"], 1) + self._test_ok(["a", "bc", "d"], ["a", "b", "c", "d"], 2) + self._test_ok(["a", 
"bc b c", "d"], ["a", "b", "cd"], 2) + self._test_ok(["abc a BX c", "def d EX f"], ["ab a b", "cd c d", "ef e f"], 4) + self._test_ok(["ab a b", "cd bc d"], ["a", "bc", "d"], 2) + self._test_ok(["a", "bc b c", "d"], ["ab AX BX", "cd CX a"], 1) diff --git a/spacy/cli/ud_train.py b/spacy/cli/ud_train.py new file mode 100644 index 000000000..bc106fb6b --- /dev/null +++ b/spacy/cli/ud_train.py @@ -0,0 +1,394 @@ +'''Train for CONLL 2017 UD treebank evaluation. Takes .conllu files, writes +.conllu format for development data, allowing the official scorer to be used. +''' +from __future__ import unicode_literals +import plac +import tqdm +import attr +from pathlib import Path +import re +import sys +import json + +import spacy +import spacy.util +from ..tokens import Token, Doc +from ..gold import GoldParse +from ..syntax.nonproj import projectivize +from ..matcher import Matcher +from collections import defaultdict, Counter +from timeit import default_timer as timer + +import itertools +import random +import numpy.random +import cytoolz + +from . import conll17_ud_eval + +from .. import lang +from .. 
from ..lang import zh
from ..lang import ja

lang.zh.Chinese.Defaults.use_jieba = False
lang.ja.Japanese.Defaults.use_janome = False

random.seed(0)
numpy.random.seed(0)


def minibatch_by_words(items, size=5000):
    '''Shuffle `items` in place and yield batches of (doc, gold) pairs.

    items (list): (doc, gold) pairs; `len(doc)` is used as the word count.
    size (int or iterator): word budget per batch, either one fixed number or
        an iterator yielding the budget for each successive batch.

    A batch is closed once its budget drops below zero, so a batch may exceed
    the budget by the length of its last documents.
    '''
    random.shuffle(items)
    if isinstance(size, int):
        size_ = itertools.repeat(size)
    else:
        size_ = size
    items = iter(items)
    while True:
        try:
            batch_size = next(size_)
        except StopIteration:
            # A finite size iterator ran out: stop cleanly rather than letting
            # StopIteration escape from the generator body (PEP 479).
            return
        batch = []
        while batch_size >= 0:
            try:
                doc, gold = next(items)
            except StopIteration:
                if batch:
                    yield batch
                return
            batch_size -= len(doc)
            batch.append((doc, gold))
        if batch:
            yield batch
        else:
            break

################
# Data reading #
################

space_re = re.compile(r'\s+')


def split_text(text):
    '''Split text into paragraphs at blank lines, collapsing whitespace runs.'''
    return [space_re.sub(' ', par.strip()) for par in text.split('\n\n')]


def read_data(nlp, conllu_file, text_file, raw_text=True, oracle_segments=False,
              max_doc_length=None, limit=None):
    '''Read the CONLLU format into (Doc, GoldParse) tuples. If raw_text=True,
    include Doc objects created using nlp.make_doc and then aligned against
    the gold-standard sequences. If oracle_segments=True, include Doc objects
    created from the gold-standard segments. At least one must be True.

    Returns a (docs, golds) pair of parallel lists. With max_doc_length set,
    raw-text docs are cut after that many sentences; with limit set, reading
    stops once that many docs have been collected.
    '''
    if not raw_text and not oracle_segments:
        raise ValueError("At least one of raw_text or oracle_segments must be True")
    paragraphs = split_text(text_file.read())
    conllu = read_conllu(conllu_file)
    # sd is spacy doc; cd is conllu doc
    # cs is conllu sent, ct is conllu token
    docs = []
    golds = []
    for doc_id, (text, cd) in enumerate(zip(paragraphs, conllu)):
        sent_annots = []
        for cs in cd:
            sent = defaultdict(list)
            for id_, word, lemma, pos, tag, morph, head, dep, _, space_after in cs:
                # Skip empty nodes ("1.1") and multi-word token ranges ("1-2").
                if '.' in id_:
                    continue
                if '-' in id_:
                    continue
                id_ = int(id_)-1
                # The root is encoded as a word that is its own head.
                head = int(head)-1 if head != '0' else id_
                sent['words'].append(word)
                sent['tags'].append(tag)
                sent['heads'].append(head)
                sent['deps'].append('ROOT' if dep == 'root' else dep)
                sent['spaces'].append(space_after == '_')
            sent['entities'] = ['-'] * len(sent['words'])
            sent['heads'], sent['deps'] = projectivize(sent['heads'],
                                                       sent['deps'])
            if oracle_segments:
                docs.append(Doc(nlp.vocab, words=sent['words'], spaces=sent['spaces']))
                golds.append(GoldParse(docs[-1], **sent))

            sent_annots.append(sent)
            if raw_text and max_doc_length and len(sent_annots) >= max_doc_length:
                doc, gold = _make_gold(nlp, None, sent_annots)
                sent_annots = []
                docs.append(doc)
                golds.append(gold)
                if limit and len(docs) >= limit:
                    return docs, golds

        if raw_text and sent_annots:
            doc, gold = _make_gold(nlp, None, sent_annots)
            docs.append(doc)
            golds.append(gold)
        if limit and len(docs) >= limit:
            return docs, golds
    return docs, golds


def read_conllu(file_):
    '''Parse CoNLL-U lines into a list of docs.

    Each doc is a list of sentences and each sentence a list of 10-column
    rows. A "# newdoc" comment starts a new doc, other comments are skipped
    and a blank line ends the current sentence.

    Raises ValueError if a row does not have exactly 10 tab-separated columns.
    '''
    docs = []
    sent = []
    doc = []
    for line in file_:
        if line.startswith('# newdoc'):
            if doc:
                docs.append(doc)
            doc = []
        elif line.startswith('#'):
            continue
        elif not line.strip():
            if sent:
                doc.append(sent)
            sent = []
        else:
            sent.append(list(line.strip().split('\t')))
            if len(sent[-1]) != 10:
                # Carry the offending line in the exception instead of
                # printing it and raising an empty error.
                raise ValueError(
                    "Expected 10 tab-separated columns, found {n}: {line}".format(
                        n=len(sent[-1]), line=repr(line)))
    if sent:
        doc.append(sent)
    if doc:
        docs.append(doc)
    return docs


def _make_gold(nlp, text, sent_annots):
    '''Merge per-sentence annotations into one (Doc, GoldParse) pair.

    If `text` is None it is rebuilt from the words and their trailing-space
    flags before being passed to nlp.make_doc.
    '''
    # Flatten the conll annotations, and adjust the head indices
    flat = defaultdict(list)
    for sent in sent_annots:
        flat['heads'].extend(len(flat['words'])+head for head in sent['heads'])
        for field in ['words', 'tags', 'deps', 'entities', 'spaces']:
            flat[field].extend(sent[field])
    # Construct text if necessary
    assert len(flat['words']) == len(flat['spaces'])
    if text is None:
        text = ''.join(word+' '*space for word, space in zip(flat['words'], flat['spaces']))
    doc = nlp.make_doc(text)
    # 'spaces' was only needed to rebuild the text; it is not passed to GoldParse.
    flat.pop('spaces')
    gold = GoldParse(doc, **flat)
    return doc, gold

#############################
# Data transforms for spaCy #
#############################

def golds_to_gold_tuples(docs, golds):
    '''Get out the annoying 'tuples' format used by begin_training, given the
    GoldParse objects.'''
    tuples = []
    for doc, gold in zip(docs, golds):
        text = doc.text
        ids, words, tags, heads, labels, iob = zip(*gold.orig_annot)
        sents = [((ids, words, tags, heads, labels, iob), [])]
        tuples.append((text, sents))
    return tuples


##############
# Evaluation #
##############

def evaluate(nlp, text_loc, gold_loc, sys_loc, limit=None):
    '''Parse the text at `text_loc`, write the parses to `sys_loc` in CoNLL-U
    format and score them against `gold_loc` with conll17_ud_eval.

    `limit` is accepted but not used. Returns the scores dict produced by
    conll17_ud_eval.evaluate.
    '''
    with text_loc.open('r', encoding='utf8') as text_file:
        texts = split_text(text_file.read())
    docs = list(nlp.pipe(texts))
    with sys_loc.open('w', encoding='utf8') as out_file:
        write_conllu(docs, out_file)
    with gold_loc.open('r', encoding='utf8') as gold_file:
        gold_ud = conll17_ud_eval.load_conllu(gold_file)
    with sys_loc.open('r', encoding='utf8') as sys_file:
        sys_ud = conll17_ud_eval.load_conllu(sys_file)
    scores = conll17_ud_eval.evaluate(gold_ud, sys_ud)
    return scores


def write_conllu(docs, file_):
    '''Write `docs` to `file_` in CoNLL-U format, one "# newdoc" per doc.

    Runs of tokens labelled 'subtok' are merged (together with the token
    that follows the run) before the doc is written.
    '''
    if not docs:
        # Nothing to write, and no doc to take the vocab from.
        return
    merger = Matcher(docs[0].vocab)
    merger.add('SUBTOK', None, [{'DEP': 'subtok', 'op': '+'}])
    for i, doc in enumerate(docs):
        matches = merger(doc)
        spans = [doc[start:end+1] for _, start, end in matches]
        # Character offsets are collected first, then the merges are applied.
        offsets = [(span.start_char, span.end_char) for span in spans]
        for start_char, end_char in offsets:
            doc.merge(start_char, end_char)
        file_.write("# newdoc id = {i}\n".format(i=i))
        for j, sent in enumerate(doc.sents):
            file_.write("# sent_id = {i}.{j}\n".format(i=i, j=j))
            file_.write("# text = {text}\n".format(text=sent.text))
            for k, token in enumerate(sent):
                file_.write(token._.get_conllu_lines(k) + '\n')
            file_.write('\n')


def print_progress(itn, losses, ud_scores):
    '''Print one tab-separated row of losses and F1 scores for epoch `itn`.

    The header row is printed before the first epoch (itn == 0).
    '''
    fields = {
        'dep_loss': losses.get('parser', 0.0),
        'tag_loss': losses.get('tagger', 0.0),
        'words': ud_scores['Words'].f1 * 100,
        'sents': ud_scores['Sentences'].f1 * 100,
        'tags': ud_scores['XPOS'].f1 * 100,
        'uas': ud_scores['UAS'].f1 * 100,
        'las': ud_scores['LAS'].f1 * 100,
    }
    header = ['Epoch', 'Loss', 'LAS', 'UAS', 'TAG', 'SENT', 'WORD']
    if itn == 0:
        print('\t'.join(header))
    tpl = '\t'.join((
        '{:d}',
        '{dep_loss:.1f}',
        '{las:.1f}',
        '{uas:.1f}',
        '{tags:.1f}',
        '{sents:.1f}',
        '{words:.1f}',
    ))
    print(tpl.format(itn, **fields))

#def get_sent_conllu(sent, sent_id):
#    lines = ["# sent_id = {sent_id}".format(sent_id=sent_id)]

def get_token_conllu(token, i):
    '''Return the CoNLL-U line(s) for `token`, the i-th (0-based) token of its
    sentence.

    If the token begins a fused (multi-word) token, a range line is emitted
    before the word line.
    '''
    if token._.begins_fused:
        n = 1
        # Count the words of the fused token without walking past the doc end.
        while token.i + n < len(token.doc) and token.nbor(n)._.inside_fused:
            n += 1
        # Word IDs are 1-based (this word is i+1), so the range covers
        # i+1 .. i+n. NOTE(review): FORM is only this token's text, not the
        # surface form of the whole fused token -- confirm what is intended.
        id_ = '%d-%d' % (i+1, i+n)
        # The range line is a single tab-separated row, like every other line.
        lines = ['\t'.join([id_, token.text, '_', '_', '_', '_', '_', '_', '_', '_'])]
    else:
        lines = []
    if token.head.i == token.i:
        head = 0
    else:
        head = i + (token.head.i - token.i) + 1
    fields = [str(i+1), token.text, token.lemma_, token.pos_, token.tag_, '_',
              str(head), token.dep_.lower(), '_', '_']
    lines.append('\t'.join(fields))
    return '\n'.join(lines)

Token.set_extension('get_conllu_lines', method=get_token_conllu)
Token.set_extension('begins_fused', default=False)
Token.set_extension('inside_fused', default=False)


##################
# Initialization #
##################


def load_nlp(corpus, config):
    '''Create a blank pipeline for the language prefix of `corpus` (the part
    before the first "_"), loading the vocab from `config.vectors` if set.'''
    lang = corpus.split('_')[0]
    nlp = spacy.blank(lang)
    if config.vectors:
        # The value comes from a JSON config and may be a plain string, which
        # does not support the "/" operator.
        nlp.vocab.from_disk(Path(config.vectors) / 'vocab')
    return nlp


def initialize_pipeline(nlp, docs, golds, config):
    '''Add a parser and a tagger to `nlp`, register the gold tags as labels
    and return the result of nlp.begin_training.'''
    nlp.add_pipe(nlp.create_pipe('parser'))
    if config.multitask_tag:
        nlp.parser.add_multitask_objective('tag')
    if config.multitask_sent:
        nlp.parser.add_multitask_objective('sent_start')
    nlp.parser.moves.add_action(2, 'subtok')
    nlp.add_pipe(nlp.create_pipe('tagger'))
    for gold in golds:
        for tag in gold.tags:
            if tag is not None:
                nlp.tagger.add_label(tag)
    # Replace labels that didn't make the frequency cutoff
    actions = set(nlp.parser.labels)
    label_set = set([act.split('-')[1] for act in actions if '-' in act])
    for gold in golds:
        for i, label in enumerate(gold.labels):
            if label is not None and label not in label_set:
                gold.labels[i] = label.split('||')[0]
    return nlp.begin_training(lambda: golds_to_gold_tuples(docs, golds))


########################
# Command line helpers #
########################

@attr.s
class Config(object):
    '''Training settings, loadable from a JSON file with Config.load.'''
    vectors = attr.ib(default=None)
    max_doc_length = attr.ib(default=10)
    multitask_tag = attr.ib(default=True)
    multitask_sent = attr.ib(default=True)
    nr_epoch = attr.ib(default=30)
    batch_size = attr.ib(default=1000)
    dropout = attr.ib(default=0.2)

    @classmethod
    def load(cls, loc):
        '''Build a Config from the JSON object stored at `loc`.'''
        with Path(loc).open('r', encoding='utf8') as file_:
            cfg = json.load(file_)
        return cls(**cfg)


class Dataset(object):
    '''The .conllu and .txt files of one section (e.g. 'train') of a treebank.

    Raises IOError if either file cannot be found in `path`.
    '''
    def __init__(self, path, section):
        self.path = path
        self.section = section
        self.conllu = None
        self.text = None
        for file_path in self.path.iterdir():
            name = file_path.parts[-1]
            if section in name and name.endswith('conllu'):
                self.conllu = file_path
            elif section in name and name.endswith('txt'):
                self.text = file_path
        if self.conllu is None:
            msg = "Could not find .conllu file in {path} for {section}"
            raise IOError(msg.format(section=section, path=path))
        if self.text is None:
            msg = "Could not find .txt file in {path} for {section}"
            raise IOError(msg.format(section=section, path=path))
        # Language code: the file name up to the first '-' or '_'.
        self.lang = self.conllu.parts[-1].split('-')[0].split('_')[0]


class TreebankPaths(object):
    '''The train and dev Datasets of one treebank directory.'''
    def __init__(self, ud_path, treebank, **cfg):
        self.train = Dataset(ud_path / treebank, 'train')
        self.dev = Dataset(ud_path / treebank, 'dev')
        self.lang = self.train.lang


@plac.annotations(
    ud_dir=("Path to Universal Dependencies corpus", "positional", None, Path),
    corpus=("UD corpus to train and evaluate on, e.g. en, es_ancora, etc",
            "positional", None, str),
    parses_dir=("Directory to write the development parses", "positional", None, Path),
    config=("Path to json formatted config file", "positional", None, Config.load),
    limit=("Size limit", "option", "n", int)
)
def main(ud_dir, parses_dir, config, corpus, limit=0):
    '''Train on the treebank's train section and, after every epoch, write the
    dev parses to parses_dir/corpus and print the scores.'''
    paths = TreebankPaths(ud_dir, corpus)
    if not (parses_dir / corpus).exists():
        (parses_dir / corpus).mkdir(parents=True)
    print("Train and evaluate", corpus, "using lang", paths.lang)
    nlp = load_nlp(paths.lang, config)

    # Read the training data with an explicit encoding and close both files
    # as soon as they have been consumed.
    with paths.train.conllu.open('r', encoding='utf8') as conllu_file:
        with paths.train.text.open('r', encoding='utf8') as text_file:
            docs, golds = read_data(nlp, conllu_file, text_file,
                                    max_doc_length=config.max_doc_length, limit=limit)

    optimizer = initialize_pipeline(nlp, docs, golds, config)

    for i in range(config.nr_epoch):
        # Fresh Doc objects are made from the raw text for every epoch.
        docs = [nlp.make_doc(doc.text) for doc in docs]
        batches = minibatch_by_words(list(zip(docs, golds)), size=config.batch_size)
        losses = {}
        n_train_words = sum(len(doc) for doc in docs)
        with tqdm.tqdm(total=n_train_words, leave=False) as pbar:
            for batch in batches:
                batch_docs, batch_gold = zip(*batch)
                pbar.update(sum(len(doc) for doc in batch_docs))
                nlp.update(batch_docs, batch_gold, sgd=optimizer,
                           drop=config.dropout, losses=losses)

        out_path = parses_dir / corpus / 'epoch-{i}.conllu'.format(i=i)
        with nlp.use_params(optimizer.averages):
            scores = evaluate(nlp, paths.dev.text, paths.dev.conllu, out_path)
            print_progress(i, losses, scores)


if __name__ == '__main__':
    plac.call(main)