mirror of https://github.com/explosion/spaCy.git
Add a tagger-based SentenceRecognizer (#4713)
* Add sent_starts to GoldParse

* Add SentTagger pipeline component

  Add `SentTagger` pipeline component as a subclass of `Tagger`.

  * Model reduces default parameters from `Tagger` to be small and fast
  * Hard-coded set of two labels:
    * S (1): token at beginning of sentence
    * I (0): all other sentence positions
  * Sets `token.sent_start` values

* Add sentence segmentation to Scorer

  Report `sent_p/r/f` for sentence boundaries, which may be provided by
  various pipeline components.

* Add sentence segmentation to CLI evaluate

* Add senttagger metrics/scoring to train CLI

* Rename SentTagger to SentenceRecognizer

* Add SentenceRecognizer to spacy.pipes imports

* Add SentenceRecognizer serialization test

* Shorten component name to sentrec

* Remove duplicates from train CLI output metrics
parent 9efd3ccbef
commit b841d3fe75
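For orientation, a rough usage sketch of the new component (not part of the commit itself): it assumes the factory ends up registered with `Language` under the shortened name "sentrec" used throughout this diff, and it glosses over training, so an untrained model would produce arbitrary boundaries.

```python
# Sketch only: assumes a "sentrec" factory is registered with Language.
import spacy

nlp = spacy.blank("en")
sentrec = nlp.create_pipe("sentrec")      # SentenceRecognizer instance
nlp.add_pipe(sentrec)
optimizer = nlp.begin_training()          # builds the small tagger-style model

doc = nlp("This is a sentence. This is another one.")
# The component writes 1/-1 into token.sent_start, so is_sent_start and
# doc.sents become available without a parser in the pipeline.
print([token.is_sent_start for token in doc])
print([sent.text for sent in doc.sents])
```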
@@ -61,6 +61,9 @@ def evaluate(
         "NER R": "%.2f" % scorer.ents_r,
         "NER F": "%.2f" % scorer.ents_f,
         "Textcat": "%.2f" % scorer.textcat_score,
+        "Sent P": "%.2f" % scorer.sent_p,
+        "Sent R": "%.2f" % scorer.sent_r,
+        "Sent F": "%.2f" % scorer.sent_f,
     }
     msg.table(results, title="Results")
@@ -11,6 +11,7 @@ import srsly
 from wasabi import msg
 import contextlib
 import random
+from collections import OrderedDict

 from .._ml import create_default_optimizer
 from ..attrs import PROB, IS_OOV, CLUSTER, LANG

@@ -585,11 +586,13 @@ def _find_best(experiment_dir, component):

 def _get_metrics(component):
     if component == "parser":
-        return ("las", "uas", "token_acc")
+        return ("las", "uas", "token_acc", "sent_f")
     elif component == "tagger":
         return ("tags_acc",)
     elif component == "ner":
         return ("ents_f", "ents_p", "ents_r")
+    elif component == "sentrec":
+        return ("sent_p", "sent_r", "sent_f",)
     return ("token_acc",)

@@ -601,14 +604,17 @@ def _configure_training_output(pipeline, use_gpu, has_beam_widths):
             row_head.extend(["Tag Loss ", " Tag % "])
             output_stats.extend(["tag_loss", "tags_acc"])
         elif pipe == "parser":
-            row_head.extend(["Dep Loss ", " UAS ", " LAS "])
-            output_stats.extend(["dep_loss", "uas", "las"])
+            row_head.extend(["Dep Loss ", " UAS ", " LAS ", "Sent P", "Sent R", "Sent F"])
+            output_stats.extend(["dep_loss", "uas", "las", "sent_p", "sent_r", "sent_f"])
         elif pipe == "ner":
             row_head.extend(["NER Loss ", "NER P ", "NER R ", "NER F "])
             output_stats.extend(["ner_loss", "ents_p", "ents_r", "ents_f"])
         elif pipe == "textcat":
             row_head.extend(["Textcat Loss", "Textcat"])
             output_stats.extend(["textcat_loss", "textcat_score"])
+        elif pipe == "sentrec":
+            row_head.extend(["Sentrec Loss", "Sent P", "Sent R", "Sent F"])
+            output_stats.extend(["sentrec_loss", "sent_p", "sent_r", "sent_f"])
     row_head.extend(["Token %", "CPU WPS"])
     output_stats.extend(["token_acc", "cpu_wps"])

@@ -618,7 +624,12 @@ def _configure_training_output(pipeline, use_gpu, has_beam_widths):

     if has_beam_widths:
         row_head.insert(1, "Beam W.")
-    return row_head, output_stats
+    # remove duplicates
+    row_head_dict = OrderedDict()
+    row_head_dict.update({k: 1 for k in row_head})
+    output_stats_dict = OrderedDict()
+    output_stats_dict.update({k: 1 for k in output_stats})
+    return row_head_dict.keys(), output_stats_dict.keys()


 def _get_progress(
@@ -631,6 +642,7 @@ def _get_progress(
     scores["ner_loss"] = losses.get("ner", 0.0)
     scores["tag_loss"] = losses.get("tagger", 0.0)
     scores["textcat_loss"] = losses.get("textcat", 0.0)
+    scores["sentrec_loss"] = losses.get("sentrec", 0.0)
     scores["cpu_wps"] = cpu_wps
     scores["gpu_wps"] = gpu_wps or 0.0
     scores.update(dev_scores)
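The de-duplication added above leans on `OrderedDict` insertion order, since both the parser and the sentrec pipe can now request the same sentence-score columns. A standalone illustration of the idiom (column names copied from the diff, output assumed from ordinary dict semantics):

```python
from collections import OrderedDict

# Duplicate "Sent P/R/F" columns appear when parser and sentrec are both in
# the pipeline; OrderedDict keeps the first occurrence and the original order,
# which a plain set() would not guarantee.
row_head = ["Dep Loss ", " UAS ", " LAS ", "Sent P", "Sent R", "Sent F",
            "Sentrec Loss", "Sent P", "Sent R", "Sent F", "Token %", "CPU WPS"]
row_head_dict = OrderedDict()
row_head_dict.update({k: 1 for k in row_head})
print(list(row_head_dict.keys()))
# ['Dep Loss ', ' UAS ', ' LAS ', 'Sent P', 'Sent R', 'Sent F',
#  'Sentrec Loss', 'Token %', 'CPU WPS']
```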
@@ -26,6 +26,7 @@ cdef class GoldParse:
     cdef public list words
     cdef public list tags
     cdef public list morphs
+    cdef public list sent_starts
     cdef public list heads
     cdef public list labels
     cdef public dict orths
@@ -497,9 +497,9 @@ def json_to_examples(doc):
                 ner.append(token.get("ner", "-"))
                 morphs.append(token.get("morph", {}))
                 if i == 0:
-                    sent_starts.append(True)
+                    sent_starts.append(1)
                 else:
-                    sent_starts.append(False)
+                    sent_starts.append(0)
             if "brackets" in sent:
                 brackets.extend((b["first"] + sent_start_i,
                                  b["last"] + sent_start_i, b["label"])

@@ -759,7 +759,7 @@ cdef class Example:
         t = self.token_annotation
         split_examples = []
         for i in range(len(t.words)):
-            if i > 0 and t.sent_starts[i] == True:
+            if i > 0 and t.sent_starts[i] == 1:
                 s_example.set_token_annotation(ids=s_ids,
                     words=s_words, tags=s_tags, heads=s_heads, deps=s_deps,
                     entities=s_ents, morphs=s_morphs,

@@ -892,6 +892,7 @@ cdef class GoldParse:
                    deps=token_annotation.deps,
                    entities=token_annotation.entities,
                    morphs=token_annotation.morphs,
+                   sent_starts=token_annotation.sent_starts,
                    cats=doc_annotation.cats,
                    links=doc_annotation.links,
                    make_projective=make_projective)

@@ -902,12 +903,13 @@ cdef class GoldParse:
         ids = list(range(len(self.words)))

         return TokenAnnotation(ids=ids, words=self.words, tags=self.tags,
-                               heads=self.heads, deps=self.labels, entities=self.ner,
-                               morphs=self.morphs)
+                               heads=self.heads, deps=self.labels,
+                               entities=self.ner, morphs=self.morphs,
+                               sent_starts=self.sent_starts)

     def __init__(self, doc, words=None, tags=None, morphs=None,
-                 heads=None, deps=None, entities=None, make_projective=False,
-                 cats=None, links=None):
+                 heads=None, deps=None, entities=None, sent_starts=None,
+                 make_projective=False, cats=None, links=None):
         """Create a GoldParse. The fields will not be initialized if len(doc) is zero.

         doc (Doc): The document the annotations refer to.

@@ -920,6 +922,8 @@ cdef class GoldParse:
         entities (iterable): A sequence of named entity annotations, either as
             BILUO tag strings, or as `(start_char, end_char, label)` tuples,
             representing the entity positions.
+        sent_starts (iterable): A sequence of sentence position tags, 1 for
+            the first word in a sentence, 0 for all others.
         cats (dict): Labels for text classification. Each key in the dictionary
             may be a string or an int, or a `(start_char, end_char, label)`
             tuple, indicating that the label is applied to only part of the
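A small sketch of passing the new `sent_starts` annotation into `GoldParse` (the token texts and values are made up; per the docstring above, 1 marks a sentence-initial word and 0 any other position):

```python
from spacy.vocab import Vocab
from spacy.tokens import Doc
from spacy.gold import GoldParse

words = ["I", "like", "cats", ".", "Dogs", "are", "fine", "too", "."]
doc = Doc(Vocab(), words=words)
# 1 = first word of a sentence, 0 = every other position
sent_starts = [1, 0, 0, 0, 1, 0, 0, 0, 0]
gold = GoldParse(doc, words=words, sent_starts=sent_starts)
# With a 1:1 alignment the annotation is copied through per token
print(gold.sent_starts)
```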
@@ -956,6 +960,8 @@ cdef class GoldParse:
             deps = [None for _ in words]
         if not morphs:
             morphs = [None for _ in words]
+        if not sent_starts:
+            sent_starts = [None for _ in words]
         if entities is None:
             entities = ["-" for _ in words]
         elif len(entities) == 0:

@@ -982,6 +988,7 @@ cdef class GoldParse:
             self.labels = [None] * len(doc)
             self.ner = [None] * len(doc)
             self.morphs = [None] * len(doc)
+            self.sent_starts = [None] * len(doc)

             # This needs to be done before we align the words
             if make_projective and heads is not None and deps is not None:

@@ -1000,7 +1007,7 @@ cdef class GoldParse:
             self.gold_to_cand = [(i if i >= 0 else None) for i in j2i]

             self.orig = TokenAnnotation(ids=list(range(len(words))), words=words, tags=tags,
-                    heads=heads, deps=deps, entities=entities, morphs=morphs,
+                    heads=heads, deps=deps, entities=entities, morphs=morphs, sent_starts=sent_starts,
                     brackets=[])

             for i, gold_i in enumerate(self.cand_to_gold):

@@ -1011,11 +1018,13 @@ cdef class GoldParse:
                     self.labels[i] = None
                     self.ner[i] = None
                     self.morphs[i] = set()
+                    self.sent_starts[i] = 0
                 if gold_i is None:
                     if i in i2j_multi:
                         self.words[i] = words[i2j_multi[i]]
                         self.tags[i] = tags[i2j_multi[i]]
                         self.morphs[i] = morphs[i2j_multi[i]]
+                        self.sent_starts[i] = sent_starts[i2j_multi[i]]
                         is_last = i2j_multi[i] != i2j_multi.get(i+1)
                         is_first = i2j_multi[i] != i2j_multi.get(i-1)
                         # Set next word in multi-token span as head, until last

@@ -1055,6 +1064,7 @@ cdef class GoldParse:
                     self.words[i] = words[gold_i]
                     self.tags[i] = tags[gold_i]
                     self.morphs[i] = morphs[gold_i]
+                    self.sent_starts[i] = sent_starts[gold_i]
                     if heads[gold_i] is None:
                         self.heads[i] = None
                     else:

@@ -1091,21 +1101,6 @@ cdef class GoldParse:
         """
         return not nonproj.is_nonproj_tree(self.heads)

-    property sent_starts:
-        def __get__(self):
-            return [self.c.sent_start[i] for i in range(self.length)]
-
-        def __set__(self, sent_starts):
-            for gold_i, is_sent_start in enumerate(sent_starts):
-                i = self.gold_to_cand[gold_i]
-                if i is not None:
-                    if is_sent_start in (1, True):
-                        self.c.sent_start[i] = 1
-                    elif is_sent_start in (-1, False):
-                        self.c.sent_start[i] = -1
-                    else:
-                        self.c.sent_start[i] = 0
-

 def docs_to_json(docs, id=0):
     """Convert a list of Doc objects into the JSON-serializable format used by
@@ -3,6 +3,7 @@ from __future__ import unicode_literals

 from .pipes import Tagger, DependencyParser, EntityRecognizer, EntityLinker
 from .pipes import TextCategorizer, Tensorizer, Pipe, Sentencizer
+from .pipes import SentenceRecognizer
 from .morphologizer import Morphologizer
 from .entityruler import EntityRuler
 from .hooks import SentenceSegmenter, SimilarityHook

@@ -20,6 +21,7 @@ __all__ = [
     "EntityRuler",
     "Sentencizer",
     "SentenceSegmenter",
+    "SentenceRecognizer",
     "SimilarityHook",
     "merge_entities",
     "merge_noun_chunks",
@@ -705,6 +705,169 @@ class Tagger(Pipe):
         return self


+@component("sentrec", assigns=["token.is_sent_start"])
+class SentenceRecognizer(Tagger):
+    """Pipeline component for sentence segmentation.
+
+    DOCS: https://spacy.io/api/sentencerecognizer
+    """
+
+    def __init__(self, vocab, model=True, **cfg):
+        self.vocab = vocab
+        self.model = model
+        self._rehearsal_model = None
+        self.cfg = OrderedDict(sorted(cfg.items()))
+        self.cfg.setdefault("cnn_maxout_pieces", 2)
+        self.cfg.setdefault("subword_features", True)
+        self.cfg.setdefault("token_vector_width", 12)
+        self.cfg.setdefault("conv_depth", 1)
+        self.cfg.setdefault("pretrained_vectors", None)
+
+    @property
+    def labels(self):
+        # labels are numbered by index internally, so this matches GoldParse
+        # and Example where the sentence-initial tag is 1 and other positions
+        # are 0
+        return tuple(["I", "S"])
+
+    def set_annotations(self, docs, batch_tag_ids, **_):
+        if isinstance(docs, Doc):
+            docs = [docs]
+        cdef Doc doc
+        for i, doc in enumerate(docs):
+            doc_tag_ids = batch_tag_ids[i]
+            if hasattr(doc_tag_ids, "get"):
+                doc_tag_ids = doc_tag_ids.get()
+            for j, tag_id in enumerate(doc_tag_ids):
+                # Don't clobber existing sentence boundaries
+                if doc.c[j].sent_start == 0:
+                    if tag_id == 1:
+                        doc.c[j].sent_start = 1
+                    else:
+                        doc.c[j].sent_start = -1
+
+    def update(self, examples, drop=0., sgd=None, losses=None):
+        self.require_model()
+        examples = Example.to_example_objects(examples)
+        if losses is not None and self.name not in losses:
+            losses[self.name] = 0.
+
+        if not any(len(ex.doc) if ex.doc else 0 for ex in examples):
+            # Handle cases where there are no tokens in any docs.
+            return
+
+        tag_scores, bp_tag_scores = self.model.begin_update([ex.doc for ex in examples], drop=drop)
+        loss, d_tag_scores = self.get_loss(examples, tag_scores)
+        bp_tag_scores(d_tag_scores, sgd=sgd)
+
+        if losses is not None:
+            losses[self.name] += loss
+
+    def get_loss(self, examples, scores):
+        scores = self.model.ops.flatten(scores)
+        tag_index = range(len(self.labels))
+        cdef int idx = 0
+        correct = numpy.zeros((scores.shape[0],), dtype="i")
+        guesses = scores.argmax(axis=1)
+        known_labels = numpy.ones((scores.shape[0], 1), dtype="f")
+        for ex in examples:
+            gold = ex.gold
+            for sent_start in gold.sent_starts:
+                if sent_start is None:
+                    correct[idx] = guesses[idx]
+                elif sent_start in tag_index:
+                    correct[idx] = sent_start
+                else:
+                    correct[idx] = 0
+                    known_labels[idx] = 0.
+                idx += 1
+        correct = self.model.ops.xp.array(correct, dtype="i")
+        d_scores = scores - to_categorical(correct, nb_classes=scores.shape[1])
+        d_scores *= self.model.ops.asarray(known_labels)
+        loss = (d_scores**2).sum()
+        docs = [ex.doc for ex in examples]
+        d_scores = self.model.ops.unflatten(d_scores, [len(d) for d in docs])
+        return float(loss), d_scores
+
+    def begin_training(self, get_examples=lambda: [], pipeline=None, sgd=None,
+                       **kwargs):
+        cdef Vocab vocab = self.vocab
+        if self.model is True:
+            for hp in ["token_vector_width", "conv_depth"]:
+                if hp in kwargs:
+                    self.cfg[hp] = kwargs[hp]
+            self.model = self.Model(len(self.labels), **self.cfg)
+        if sgd is None:
+            sgd = self.create_optimizer()
+        return sgd
+
+    @classmethod
+    def Model(cls, n_tags, **cfg):
+        return build_tagger_model(n_tags, **cfg)
+
+    def add_label(self, label, values=None):
+        raise NotImplementedError
+
+    def use_params(self, params):
+        with self.model.use_params(params):
+            yield
+
+    def to_bytes(self, exclude=tuple(), **kwargs):
+        serialize = OrderedDict()
+        if self.model not in (None, True, False):
+            serialize["model"] = self.model.to_bytes
+        serialize["vocab"] = self.vocab.to_bytes
+        serialize["cfg"] = lambda: srsly.json_dumps(self.cfg)
+        exclude = util.get_serialization_exclude(serialize, exclude, kwargs)
+        return util.to_bytes(serialize, exclude)
+
+    def from_bytes(self, bytes_data, exclude=tuple(), **kwargs):
+        def load_model(b):
+            if self.model is True:
+                self.model = self.Model(len(self.labels), **self.cfg)
+            try:
+                self.model.from_bytes(b)
+            except AttributeError:
+                raise ValueError(Errors.E149)
+
+        deserialize = OrderedDict((
+            ("vocab", lambda b: self.vocab.from_bytes(b)),
+            ("cfg", lambda b: self.cfg.update(srsly.json_loads(b))),
+            ("model", lambda b: load_model(b)),
+        ))
+        exclude = util.get_serialization_exclude(deserialize, exclude, kwargs)
+        util.from_bytes(bytes_data, deserialize, exclude)
+        return self
+
+    def to_disk(self, path, exclude=tuple(), **kwargs):
+        serialize = OrderedDict((
+            ("vocab", lambda p: self.vocab.to_disk(p)),
+            ("model", lambda p: p.open("wb").write(self.model.to_bytes())),
+            ("cfg", lambda p: srsly.write_json(p, self.cfg))
+        ))
+        exclude = util.get_serialization_exclude(serialize, exclude, kwargs)
+        util.to_disk(path, serialize, exclude)
+
+    def from_disk(self, path, exclude=tuple(), **kwargs):
+        def load_model(p):
+            if self.model is True:
+                self.model = self.Model(len(self.labels), **self.cfg)
+            with p.open("rb") as file_:
+                try:
+                    self.model.from_bytes(file_.read())
+                except AttributeError:
+                    raise ValueError(Errors.E149)
+
+        deserialize = OrderedDict((
+            ("cfg", lambda p: self.cfg.update(_load_cfg(p))),
+            ("vocab", lambda p: self.vocab.from_disk(p)),
+            ("model", load_model),
+        ))
+        exclude = util.get_serialization_exclude(deserialize, exclude, kwargs)
+        util.from_disk(path, deserialize, exclude)
+        return self
+
+
 @component("nn_labeller")
 class MultitaskObjective(Tagger):
     """Experimental: Assist training of a parser or tagger, by training a

@@ -1589,4 +1752,4 @@ Language.factories["parser"] = lambda nlp, **cfg: DependencyParser.from_nlp(nlp,
 Language.factories["ner"] = lambda nlp, **cfg: EntityRecognizer.from_nlp(nlp, **cfg)


-__all__ = ["Tagger", "DependencyParser", "EntityRecognizer", "Tensorizer", "TextCategorizer", "EntityLinker", "Sentencizer"]
+__all__ = ["Tagger", "DependencyParser", "EntityRecognizer", "Tensorizer", "TextCategorizer", "EntityLinker", "Sentencizer", "SentenceRecognizer"]
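The boundary-setting rule in `set_annotations` above can be restated in plain Python. This toy version (a hypothetical helper using ints instead of `Doc` internals) just shows that an existing boundary (`sent_start != 0`) is never clobbered and that the non-initial tag maps to -1:

```python
def apply_sent_tags(current_sent_starts, tag_ids):
    """Toy restatement of SentenceRecognizer.set_annotations for one doc.

    current_sent_starts: ints per token (0 = unset, 1 = start, -1 = not a start)
    tag_ids: predicted label indices per token (1 = "S"/sentence-initial, 0 = "I")
    """
    out = list(current_sent_starts)
    for j, tag_id in enumerate(tag_ids):
        if out[j] == 0:  # don't clobber existing sentence boundaries
            out[j] = 1 if tag_id == 1 else -1
    return out

# Token 0 was already forced to be a sentence start by an earlier component:
print(apply_sent_tags([1, 0, 0, 0], [0, 0, 1, 0]))  # -> [1, -1, 1, -1]
```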
@@ -84,6 +84,7 @@ class Scorer(object):
         self.labelled = PRFScore()
         self.labelled_per_dep = dict()
         self.tags = PRFScore()
+        self.sent_starts = PRFScore()
         self.ner = PRFScore()
         self.ner_per_ents = dict()
         self.eval_punct = eval_punct

@@ -113,6 +114,27 @@ class Scorer(object):
         """
         return self.tags.fscore * 100

+    @property
+    def sent_p(self):
+        """RETURNS (float): Precision for identification of sentence starts
+        (i.e. `Token.is_sent_start`).
+        """
+        return self.sent_starts.precision * 100
+
+    @property
+    def sent_r(self):
+        """RETURNS (float): Recall for identification of sentence starts
+        (i.e. `Token.is_sent_start`).
+        """
+        return self.sent_starts.recall * 100
+
+    @property
+    def sent_f(self):
+        """RETURNS (float): F-score for identification of sentence starts
+        (i.e. `Token.is_sent_start`).
+        """
+        return self.sent_starts.fscore * 100
+
     @property
     def token_acc(self):
         """RETURNS (float): Tokenization accuracy."""

@@ -212,6 +234,9 @@ class Scorer(object):
             "ents_f": self.ents_f,
             "ents_per_type": self.ents_per_type,
             "tags_acc": self.tags_acc,
+            "sent_p": self.sent_p,
+            "sent_r": self.sent_r,
+            "sent_f": self.sent_f,
             "token_acc": self.token_acc,
             "textcat_score": self.textcat_score,
             "textcats_per_cat": self.textcats_per_cat,

@@ -242,9 +267,12 @@ class Scorer(object):
         gold_deps = set()
         gold_deps_per_dep = {}
         gold_tags = set()
+        gold_sent_starts = set()
         gold_ents = set(tags_to_entities(orig.entities))
-        for id_, tag, head, dep in zip(orig.ids, orig.tags, orig.heads, orig.deps):
+        for id_, tag, head, dep, sent_start in zip(orig.ids, orig.tags, orig.heads, orig.deps, orig.sent_starts):
             gold_tags.add((id_, tag))
+            if sent_start:
+                gold_sent_starts.add(id_)
             if dep not in (None, "") and dep.lower() not in punct_labels:
                 gold_deps.add((id_, head, dep.lower()))
                 if dep.lower() not in self.labelled_per_dep:

@@ -255,6 +283,7 @@ class Scorer(object):
         cand_deps = set()
         cand_deps_per_dep = {}
         cand_tags = set()
+        cand_sent_starts = set()
         for token in doc:
             if token.orth_.isspace():
                 continue

@@ -264,6 +293,8 @@ class Scorer(object):
             else:
                 self.tokens.tp += 1
                 cand_tags.add((gold_i, token.tag_))
+                if token.is_sent_start:
+                    cand_sent_starts.add(gold_i)
             if token.dep_.lower() not in punct_labels and token.orth_.strip():
                 gold_head = gold.cand_to_gold[token.head.i]
                 # None is indistinct, so we can't just add it to the set

@@ -308,6 +339,7 @@ class Scorer(object):
         # Score for all ents
         self.ner.score_set(cand_ents, gold_ents)
         self.tags.score_set(cand_tags, gold_tags)
+        self.sent_starts.score_set(cand_sent_starts, gold_sent_starts)
         self.labelled.score_set(cand_deps, gold_deps)
         for dep in self.labelled_per_dep:
             self.labelled_per_dep[dep].score_set(cand_deps_per_dep.get(dep, set()), gold_deps_per_dep.get(dep, set()))
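Sentence boundaries are scored as sets of gold-aligned token ids fed to a `PRFScore`; the stand-in below (not spaCy's actual class) illustrates what `sent_p`/`sent_r` measure on such sets:

```python
# Minimal stand-in for spacy.scorer.PRFScore, for illustration only.
class TinyPRF:
    def __init__(self):
        self.tp = self.fp = self.fn = 0

    def score_set(self, cand, gold):
        self.tp += len(cand & gold)
        self.fp += len(cand - gold)
        self.fn += len(gold - cand)

    @property
    def precision(self):
        return self.tp / (self.tp + self.fp + 1e-100)

    @property
    def recall(self):
        return self.tp / (self.tp + self.fn + 1e-100)

gold_sent_starts = {0, 5, 12}   # token ids where gold sentences begin
cand_sent_starts = {0, 5, 9}    # token ids the pipeline marked as sentence-initial
prf = TinyPRF()
prf.score_set(cand_sent_starts, gold_sent_starts)
print(round(prf.precision * 100, 1), round(prf.recall * 100, 1))  # 66.7 66.7
```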
@@ -3,7 +3,7 @@ from __future__ import unicode_literals

 import pytest
 from spacy.pipeline import Tagger, DependencyParser, EntityRecognizer
-from spacy.pipeline import Tensorizer, TextCategorizer
+from spacy.pipeline import Tensorizer, TextCategorizer, SentenceRecognizer

 from ..util import make_tempdir

@@ -144,3 +144,10 @@ def test_serialize_pipe_exclude(en_vocab, Parser):
         parser.to_bytes(cfg=False, exclude=["vocab"])
     with pytest.raises(ValueError):
         get_new_parser().from_bytes(parser.to_bytes(exclude=["vocab"]), cfg=False)
+
+
+def test_serialize_sentencerecognizer(en_vocab):
+    sr = SentenceRecognizer(en_vocab)
+    sr_b = sr.to_bytes()
+    sr_d = SentenceRecognizer(en_vocab).from_bytes(sr_b)
+    assert sr.to_bytes() == sr_d.to_bytes()