mirror of https://github.com/explosion/spaCy.git
325 lines
9.9 KiB
Python
325 lines
9.9 KiB
Python
# flake8: noqa
|
|
"""Train for CONLL 2017 UD treebank evaluation. Takes .conllu files, writes
|
|
.conllu format for development data, allowing the official scorer to be used.
|
|
"""
|
|
from __future__ import unicode_literals
|
|
|
|
import plac
|
|
from pathlib import Path
|
|
import re
|
|
import sys
|
|
import srsly
|
|
|
|
import spacy
|
|
import spacy.util
|
|
from spacy.tokens import Token, Doc
|
|
from spacy.matcher import Matcher
|
|
|
|
|
|
# Morphology feature values that ``_get_token_conllu`` passes to
# ``Token.check_morph`` to detect the first / inner tokens of a fused word.
# NOTE(review): nothing in this file assigns real values, so those checks
# receive ``None`` — confirm this placeholder is intended.
Fused_begin = Fused_inside = None
|
|
|
|
from . import conll17_ud_eval
|
|
|
|
from spacy import lang
|
|
from spacy.lang import zh
|
|
from spacy.lang import ja
|
|
from spacy.lang import ru
|
|
|
|
|
|
################
|
|
# Data reading #
|
|
################
|
|
|
|
space_re = re.compile(r"\s+")


def split_text(text):
    """Split ``text`` into paragraphs separated by two consecutive newlines.

    Every paragraph is stripped, and each run of whitespace inside it is
    collapsed to a single space.
    """
    paragraphs = []
    for chunk in text.split("\n\n"):
        paragraphs.append(space_re.sub(" ", chunk.strip()))
    return paragraphs
|
|
|
|
|
|
##############
|
|
# Evaluation #
|
|
##############
|
|
|
|
|
|
def read_conllu(file_):
    """Parse CoNLL-U lines into documents, sentences and token rows.

    Args:
        file_: Iterable of text lines, e.g. an open file.

    Returns:
        A list of documents. Each document is a list of sentences, and each
        sentence is a list of rows. A row is the list of the ten
        tab-separated columns of one token line.
        - Lines starting with ``# newdoc`` begin a new document.
        - Any other ``#`` comment line is ignored.
        - A blank line ends the current sentence.

    Raises:
        ValueError: If a token line does not have exactly ten columns.
    """
    docs = []
    doc = []
    sent = []
    for line in file_:
        if line.startswith("# newdoc"):
            # Flush a sentence that was not terminated by a blank line, so it
            # stays in the document it belongs to rather than the next one.
            if sent:
                doc.append(sent)
                sent = []
            if doc:
                docs.append(doc)
                doc = []
        elif line.startswith("#"):
            continue
        elif not line.strip():
            if sent:
                doc.append(sent)
                sent = []
        else:
            row = line.strip().split("\t")
            if len(row) != 10:
                raise ValueError(
                    "Expected 10 tab-separated CoNLL-U columns, got {n}: "
                    "{line!r}".format(n=len(row), line=line)
                )
            sent.append(row)
    if sent:
        doc.append(sent)
    if doc:
        docs.append(doc)
    return docs
|
|
|
|
|
|
def _conllu_surface_words(conllu_sent):
    """Return the surface word forms of one sentence parsed by ``read_conllu``.

    Each row contributes its FORM column, with these exceptions:
    - A multiword-token range row (ID such as ``"3-4"``) contributes its
      FORM, and the rows it covers are skipped.
    - Empty-node rows (ID such as ``"8.1"``) are skipped.
    This way the resulting words reproduce the sentence text once, instead
    of repeating the characters of fused words.
    """
    words = []
    covered_until = 0
    for row in conllu_sent:
        token_id = row[0]
        if "." in token_id:
            continue
        if "-" in token_id:
            words.append(row[1])
            covered_until = int(token_id.split("-", 1)[1])
        elif int(token_id) > covered_until:
            words.append(row[1])
    return words


def evaluate(nlp, text_loc, gold_loc, sys_loc, limit=None):
    """Process an input file with ``nlp``, write CoNLL-U output and score it.

    Args:
        nlp: Loaded spaCy pipeline.
        text_loc: ``Path`` to the input.
            - If its name ends in ``.conllu``, each sentence becomes a
              pre-tokenized ``Doc`` (surface forms only), which is run
              through every pipeline component.
            - Otherwise the file is read as raw text, split into paragraphs
              with ``split_text`` and processed with ``nlp.pipe``.
        gold_loc: ``Path`` to the gold-standard ``.conllu`` file.
        sys_loc: ``Path`` the system output is written to and re-read from.
        limit: Accepted for interface compatibility; not used.

    Returns:
        A ``(docs, scores)`` tuple: the processed docs, and the result of
        ``conll17_ud_eval.evaluate`` on the gold and system files.
    """
    if text_loc.parts[-1].endswith(".conllu"):
        docs = []
        with text_loc.open(encoding="utf8") as file_:
            for conllu_doc in read_conllu(file_):
                for conllu_sent in conllu_doc:
                    words = _conllu_surface_words(conllu_sent)
                    docs.append(Doc(nlp.vocab, words=words))
        for name, component in nlp.pipeline:
            docs = list(component.pipe(docs))
    else:
        with text_loc.open("r", encoding="utf8") as text_file:
            texts = split_text(text_file.read())
        docs = list(nlp.pipe(texts))
    with sys_loc.open("w", encoding="utf8") as out_file:
        write_conllu(docs, out_file)
    with gold_loc.open("r", encoding="utf8") as gold_file:
        gold_ud = conll17_ud_eval.load_conllu(gold_file)
    with sys_loc.open("r", encoding="utf8") as sys_file:
        sys_ud = conll17_ud_eval.load_conllu(sys_file)
    scores = conll17_ud_eval.evaluate(gold_ud, sys_ud)
    return docs, scores
|
|
|
|
|
|
def write_conllu(docs, file_):
    """Write ``docs`` to ``file_`` in CoNLL-U format.

    For parsed docs, every run of tokens labelled ``subtok`` is first merged
    together with the token that follows it. Each sentence is then written
    with ``newdoc``/``sent_id``/``text`` comment lines, and is checked for a
    root token after being written.

    Args:
        docs: Sequence of spaCy ``Doc`` objects. Tokens are merged in place.
        file_: Writable text file object.

    Raises:
        ValueError: If a written sentence has no token that is its own head
            with dependency label ``ROOT``.
    """
    if not docs:
        # Nothing to write; also avoids indexing ``docs[0]`` below.
        return
    merger = Matcher(docs[0].vocab)
    merger.add("SUBTOK", None, [{"DEP": "subtok", "op": "+"}])
    for i, doc in enumerate(docs):
        matches = merger(doc) if doc.is_parsed else []
        # A "+" pattern yields every overlapping sub-run of subtok tokens,
        # but merges must not overlap. Accept the longest runs first and skip
        # any candidate that overlaps an accepted one. ``end`` is exclusive,
        # so ``end + 1`` also takes in the token following the run.
        candidates = sorted(
            {(start, min(end + 1, len(doc))) for _, start, end in matches},
            key=lambda bounds: (bounds[0] - bounds[1], bounds[0]),
        )
        claimed = set()
        spans = []
        for start, end in candidates:
            positions = set(range(start, end))
            if positions.isdisjoint(claimed):
                claimed.update(positions)
                spans.append(doc[start:end])
        spans.sort(key=lambda span: span.start)
        with doc.retokenize() as retokenizer:
            for span in spans:
                retokenizer.merge(span)
        file_.write("# newdoc id = {i}\n".format(i=i))
        for j, sent in enumerate(doc.sents):
            file_.write("# sent_id = {i}.{j}\n".format(i=i, j=j))
            file_.write("# text = {text}\n".format(text=sent.text))
            for k, token in enumerate(sent):
                file_.write(_get_token_conllu(token, k, len(sent)) + "\n")
            file_.write("\n")
            has_root = any(
                word.head.i == word.i and word.dep_ == "ROOT" for word in sent
            )
            if not has_root:
                details = "\n".join(
                    "{} {} {} {} {}".format(w.i, w.text, w.head.text, w.head.i, w.dep_)
                    for w in sent
                )
                raise ValueError(
                    "Rootless sentence in doc {i}: {sent}\n{details}".format(
                        i=i, sent=sent.text, details=details
                    )
                )
|
|
|
|
|
|
def _get_token_conllu(token, k, sent_len):
    """Format one token as CoNLL-U line(s).

    Args:
        token: The spaCy ``Token`` to serialise.
        k: 0-based index of ``token`` within its sentence; the ID column is
            written as ``k + 1``.
        sent_len: Number of tokens in the sentence containing ``token``.

    Returns:
        The line(s) joined with newlines, without a trailing newline.
        If the token passes the ``Fused_begin`` morphology check and is not
        the last token of its sentence, a range line comes first. That line
        has ID ``"start-end"``, the concatenated text of the fused tokens,
        and ``_`` in every other column.
    """
    begins_fused = token.check_morph(Fused_begin) and (k + 1 < sent_len)
    lines = []
    if begins_fused:
        n = 1
        text = [token.text]
        # Bound the scan by the sentence length. Otherwise it could walk into
        # the next sentence, or past the end of the doc.
        while k + n < sent_len and token.nbor(n).check_morph(Fused_inside):
            text.append(token.nbor(n).text)
            n += 1
        range_id = "%d-%d" % (k + 1, k + n)
        lines.append("\t".join([range_id, "".join(text)] + ["_"] * 8))

    # HEAD is the 1-based position of the head within the sentence; 0 marks
    # a token that is its own head.
    if token.head.i == token.i:
        head = 0
    else:
        head = k + (token.head.i - token.i) + 1
    fields = [
        str(k + 1),
        token.text,
        token.lemma_,
        token.pos_,
        token.tag_,
        "_",
        str(head),
        token.dep_.lower(),
        "_",
        "_",
    ]
    if begins_fused:
        norm = token.norm_
        # ``[:1]`` rather than ``[0]`` so an empty norm does not raise.
        fields[1] = norm[:1].upper() + norm[1:] if k == 0 else norm
    elif token.check_morph(Fused_inside):
        fields[1] = token.norm_
    elif token._.split_start is not None:
        # Divide the text of the group's first token across all tokens of the
        # group with guess_fused_orths, and write this token's piece.
        split_start = token._.split_start
        split_end = token._.split_end
        split_len = (split_end.i - split_start.i) + 1
        n_in_split = token.i - split_start.i
        subtokens = guess_fused_orths(split_start.text, [""] * split_len)
        fields[1] = subtokens[n_in_split]

    lines.append("\t".join(fields))
    return "\n".join(lines)
|
|
|
|
|
|
def guess_fused_orths(word, ud_forms):
    """Divide ``word`` into one orth per entry of ``ud_forms``.

    UD fused-token expansions need not spell out the original string, but
    the orths produced here must concatenate back to ``word`` exactly.
    Strategy, in order:
    1. If ``ud_forms`` already joins to ``word``, return it unchanged.
    2. If the total length matches, cut ``word`` at the same lengths.
    3. Otherwise give the first piece everything except one character per
       remaining form, e.g. 6 characters and 3 forms -> lengths [4, 1, 1].
    """
    if "".join(ud_forms) == word:
        return ud_forms
    form_count = len(ud_forms)
    if sum(map(len, ud_forms)) == len(word):
        pieces = []
        offset = 0
        for form in ud_forms:
            assert len(form) >= 1
            pieces.append(word[offset : offset + len(form)])
            offset += len(form)
        leftover = word[offset:]
        assert len(leftover) == 0, (word, ud_forms, leftover)
        return pieces
    head = word[: len(word) - (form_count - 1)]
    rest = word[len(head) :]
    pieces = [head]
    for _ in range(form_count - 1):
        assert rest
        pieces.append(rest[:1])
        rest = rest[1:]
    assert len(rest) == 0, (word, pieces, rest)
    return pieces
|
|
|
|
|
|
def print_results(name, ud_scores):
    """Print one tab-separated score row for ``name`` and return the scores.

    The row columns are LAS, UAS, tags, sentences and words, each shown as a
    percentage with one decimal. When ``ud_scores`` is ``None``, every score
    is 0.0. The returned dict maps the short keys to the percentage values.
    """
    score_keys = (
        ("words", "Words"),
        ("sents", "Sentences"),
        ("tags", "XPOS"),
        ("uas", "UAS"),
        ("las", "LAS"),
    )
    if ud_scores is None:
        fields = {short: 0.0 for short, _ in score_keys}
    else:
        fields = {short: ud_scores[long].f1 * 100 for short, long in score_keys}
    columns = [name] + [
        "{%s:.1f}" % key for key in ("las", "uas", "tags", "sents", "words")
    ]
    print("\t".join(columns).format(**fields))
    return fields
|
|
|
|
|
|
def get_token_split_start(token):
    """Return the first token of the empty-text group that ``token`` is in.

    - A token with empty text maps to the nearest preceding token whose text
      is non-empty. The token must not be at position 0.
    - A non-empty token maps to itself if the next token has empty text.
    - Otherwise the result is ``None``.
    """
    if token.text != "":
        has_next = (token.i + 1) < len(token.doc)
        if has_next and token.nbor(1).text == "":
            return token
        return None
    assert token.i != 0
    offset = -1
    while token.nbor(offset).text == "":
        offset -= 1
    return token.nbor(offset)
|
|
|
|
|
|
def get_token_split_end(token):
    """Return the last token of the empty-text run that starts at or after ``token``.

    - The last token of the doc maps to itself if its text is empty, else
      to ``None``.
    - A non-empty token followed by a non-empty token maps to ``None``.
    - Otherwise the result is the last consecutive empty-text token from
      ``token`` onwards. If ``token`` is empty and the next token is not,
      that is ``token`` itself.
    """
    doc_len = len(token.doc)
    if token.i + 1 == doc_len:
        return None if token.text != "" else token
    if token.text != "" and token.nbor(1).text != "":
        return None
    step = 1
    while token.i + step < doc_len and token.nbor(step).text == "":
        step += 1
    return token.nbor(step - 1)
|
|
|
|
|
|
##################
|
|
# Initialization #
|
|
##################
|
|
|
|
|
|
def load_nlp(experiments_dir, corpus):
    """Load the spaCy model stored at ``experiments_dir / corpus / "best-model"``."""
    model_path = experiments_dir / corpus / "best-model"
    return spacy.load(model_path)
|
|
|
|
|
|
def initialize_pipeline(nlp, examples, config, device):
    """Append a newly created ``"parser"`` component to ``nlp`` and return ``nlp``.

    ``examples``, ``config`` and ``device`` are accepted but not used.
    """
    parser = nlp.create_pipe("parser")
    nlp.add_pipe(parser)
    return nlp
|
|
|
|
|
|
@plac.annotations(
    test_data_dir=(
        "Path to Universal Dependencies test data",
        "positional",
        None,
        Path,
    ),
    experiment_dir=("Parent directory with output model", "positional", None, Path),
    corpus=(
        "UD corpus to evaluate, e.g. UD_English, UD_Spanish, etc",
        "positional",
        None,
        str,
    ),
)
def main(test_data_dir, experiment_dir, corpus):
    """Evaluate ``experiment_dir / corpus / "best-model"`` on the test and dev sections.

    For each section, the model is run on the UDPipe ``.conllu`` input
    ("udp") and on the raw ``.txt`` input ("raw"). A score row is printed
    for each run, and results are written under ``experiment_dir / corpus``:
    - raw run: ``{section}.conllu`` and ``{section}-accuracy.json``
    - udp run: ``{section}-udp.conllu`` and ``{section}-udp-accuracy.json``
    """
    Token.set_extension("split_start", getter=get_token_split_start)
    Token.set_extension("split_end", getter=get_token_split_end)
    Token.set_extension("begins_fused", default=False)
    Token.set_extension("inside_fused", default=False)
    lang.zh.Chinese.Defaults.use_jieba = False
    lang.ja.Japanese.Defaults.use_janome = False
    lang.ru.Russian.Defaults.use_pymorphy2 = False

    nlp = load_nlp(experiment_dir, corpus)

    treebank_code = nlp.meta["treebank"]
    model_dir = experiment_dir / corpus
    for section in ("test", "dev"):
        if section == "dev":
            section_dir = "conll17-ud-development-2017-03-19"
        else:
            section_dir = "conll17-ud-test-2017-05-09"
        input_dir = test_data_dir / "input" / section_dir
        inputs = {
            "udp": input_dir / (treebank_code + "-udpipe.conllu"),
            "raw": input_dir / (treebank_code + ".txt"),
        }
        gold_path = test_data_dir / "gold" / section_dir / (treebank_code + ".conllu")

        header = [section, "LAS", "UAS", "TAG", "SENT", "WORD"]
        print("\t".join(header))
        for input_type in ("udp", "raw"):
            # Give each input type its own output files, so the UDPipe run is
            # not overwritten by the raw-text run. The raw run keeps the
            # un-suffixed names.
            suffix = "" if input_type == "raw" else "-" + input_type
            output_path = model_dir / "{section}{suffix}.conllu".format(
                section=section, suffix=suffix
            )

            _, test_scores = evaluate(nlp, inputs[input_type], gold_path, output_path)

            accuracy = print_results(input_type, test_scores)
            acc_path = model_dir / "{section}{suffix}-accuracy.json".format(
                section=section, suffix=suffix
            )
            srsly.write_json(acc_path, accuracy)


if __name__ == "__main__":
    plac.call(main)
|