mirror of https://github.com/explosion/spaCy.git
Add a tagger-based SentenceRecognizer (#4713)
* Add sent_starts to GoldParse * Add SentTagger pipeline component Add `SentTagger` pipeline component as a subclass of `Tagger`. * Model reduces default parameters from `Tagger` to be small and fast * Hard-coded set of two labels: * S (1): token at beginning of sentence * I (0): all other sentence positions * Sets `token.sent_start` values * Add sentence segmentation to Scorer Report `sent_p/r/f` for sentence boundaries, which may be provided by various pipeline components. * Add sentence segmentation to CLI evaluate * Add senttagger metrics/scoring to train CLI * Rename SentTagger to SentenceRecognizer * Add SentenceRecognizer to spacy.pipes imports * Add SentenceRecognizer serialization test * Shorten component name to sentrec * Remove duplicates from train CLI output metrics
This commit is contained in:
parent
9efd3ccbef
commit
b841d3fe75
|
@ -61,6 +61,9 @@ def evaluate(
|
|||
"NER R": "%.2f" % scorer.ents_r,
|
||||
"NER F": "%.2f" % scorer.ents_f,
|
||||
"Textcat": "%.2f" % scorer.textcat_score,
|
||||
"Sent P": "%.2f" % scorer.sent_p,
|
||||
"Sent R": "%.2f" % scorer.sent_r,
|
||||
"Sent F": "%.2f" % scorer.sent_f,
|
||||
}
|
||||
msg.table(results, title="Results")
|
||||
|
||||
|
|
|
@ -11,6 +11,7 @@ import srsly
|
|||
from wasabi import msg
|
||||
import contextlib
|
||||
import random
|
||||
from collections import OrderedDict
|
||||
|
||||
from .._ml import create_default_optimizer
|
||||
from ..attrs import PROB, IS_OOV, CLUSTER, LANG
|
||||
|
@ -585,11 +586,13 @@ def _find_best(experiment_dir, component):
|
|||
|
||||
def _get_metrics(component):
|
||||
if component == "parser":
|
||||
return ("las", "uas", "token_acc")
|
||||
return ("las", "uas", "token_acc", "sent_f")
|
||||
elif component == "tagger":
|
||||
return ("tags_acc",)
|
||||
elif component == "ner":
|
||||
return ("ents_f", "ents_p", "ents_r")
|
||||
elif component == "sentrec":
|
||||
return ("sent_p", "sent_r", "sent_f",)
|
||||
return ("token_acc",)
|
||||
|
||||
|
||||
|
@ -601,14 +604,17 @@ def _configure_training_output(pipeline, use_gpu, has_beam_widths):
|
|||
row_head.extend(["Tag Loss ", " Tag % "])
|
||||
output_stats.extend(["tag_loss", "tags_acc"])
|
||||
elif pipe == "parser":
|
||||
row_head.extend(["Dep Loss ", " UAS ", " LAS "])
|
||||
output_stats.extend(["dep_loss", "uas", "las"])
|
||||
row_head.extend(["Dep Loss ", " UAS ", " LAS ", "Sent P", "Sent R", "Sent F"])
|
||||
output_stats.extend(["dep_loss", "uas", "las", "sent_p", "sent_r", "sent_f"])
|
||||
elif pipe == "ner":
|
||||
row_head.extend(["NER Loss ", "NER P ", "NER R ", "NER F "])
|
||||
output_stats.extend(["ner_loss", "ents_p", "ents_r", "ents_f"])
|
||||
elif pipe == "textcat":
|
||||
row_head.extend(["Textcat Loss", "Textcat"])
|
||||
output_stats.extend(["textcat_loss", "textcat_score"])
|
||||
elif pipe == "sentrec":
|
||||
row_head.extend(["Sentrec Loss", "Sent P", "Sent R", "Sent F"])
|
||||
output_stats.extend(["sentrec_loss", "sent_p", "sent_r", "sent_f"])
|
||||
row_head.extend(["Token %", "CPU WPS"])
|
||||
output_stats.extend(["token_acc", "cpu_wps"])
|
||||
|
||||
|
@ -618,7 +624,12 @@ def _configure_training_output(pipeline, use_gpu, has_beam_widths):
|
|||
|
||||
if has_beam_widths:
|
||||
row_head.insert(1, "Beam W.")
|
||||
return row_head, output_stats
|
||||
# remove duplicates
|
||||
row_head_dict = OrderedDict()
|
||||
row_head_dict.update({k: 1 for k in row_head})
|
||||
output_stats_dict = OrderedDict()
|
||||
output_stats_dict.update({k: 1 for k in output_stats})
|
||||
return row_head_dict.keys(), output_stats_dict.keys()
|
||||
|
||||
|
||||
def _get_progress(
|
||||
|
@ -631,6 +642,7 @@ def _get_progress(
|
|||
scores["ner_loss"] = losses.get("ner", 0.0)
|
||||
scores["tag_loss"] = losses.get("tagger", 0.0)
|
||||
scores["textcat_loss"] = losses.get("textcat", 0.0)
|
||||
scores["sentrec_loss"] = losses.get("sentrec", 0.0)
|
||||
scores["cpu_wps"] = cpu_wps
|
||||
scores["gpu_wps"] = gpu_wps or 0.0
|
||||
scores.update(dev_scores)
|
||||
|
|
|
@ -26,6 +26,7 @@ cdef class GoldParse:
|
|||
cdef public list words
|
||||
cdef public list tags
|
||||
cdef public list morphs
|
||||
cdef public list sent_starts
|
||||
cdef public list heads
|
||||
cdef public list labels
|
||||
cdef public dict orths
|
||||
|
|
|
@ -497,9 +497,9 @@ def json_to_examples(doc):
|
|||
ner.append(token.get("ner", "-"))
|
||||
morphs.append(token.get("morph", {}))
|
||||
if i == 0:
|
||||
sent_starts.append(True)
|
||||
sent_starts.append(1)
|
||||
else:
|
||||
sent_starts.append(False)
|
||||
sent_starts.append(0)
|
||||
if "brackets" in sent:
|
||||
brackets.extend((b["first"] + sent_start_i,
|
||||
b["last"] + sent_start_i, b["label"])
|
||||
|
@ -759,7 +759,7 @@ cdef class Example:
|
|||
t = self.token_annotation
|
||||
split_examples = []
|
||||
for i in range(len(t.words)):
|
||||
if i > 0 and t.sent_starts[i] == True:
|
||||
if i > 0 and t.sent_starts[i] == 1:
|
||||
s_example.set_token_annotation(ids=s_ids,
|
||||
words=s_words, tags=s_tags, heads=s_heads, deps=s_deps,
|
||||
entities=s_ents, morphs=s_morphs,
|
||||
|
@ -892,6 +892,7 @@ cdef class GoldParse:
|
|||
deps=token_annotation.deps,
|
||||
entities=token_annotation.entities,
|
||||
morphs=token_annotation.morphs,
|
||||
sent_starts=token_annotation.sent_starts,
|
||||
cats=doc_annotation.cats,
|
||||
links=doc_annotation.links,
|
||||
make_projective=make_projective)
|
||||
|
@ -902,12 +903,13 @@ cdef class GoldParse:
|
|||
ids = list(range(len(self.words)))
|
||||
|
||||
return TokenAnnotation(ids=ids, words=self.words, tags=self.tags,
|
||||
heads=self.heads, deps=self.labels, entities=self.ner,
|
||||
morphs=self.morphs)
|
||||
heads=self.heads, deps=self.labels,
|
||||
entities=self.ner, morphs=self.morphs,
|
||||
sent_starts=self.sent_starts)
|
||||
|
||||
def __init__(self, doc, words=None, tags=None, morphs=None,
|
||||
heads=None, deps=None, entities=None, make_projective=False,
|
||||
cats=None, links=None):
|
||||
heads=None, deps=None, entities=None, sent_starts=None,
|
||||
make_projective=False, cats=None, links=None):
|
||||
"""Create a GoldParse. The fields will not be initialized if len(doc) is zero.
|
||||
|
||||
doc (Doc): The document the annotations refer to.
|
||||
|
@ -920,6 +922,8 @@ cdef class GoldParse:
|
|||
entities (iterable): A sequence of named entity annotations, either as
|
||||
BILUO tag strings, or as `(start_char, end_char, label)` tuples,
|
||||
representing the entity positions.
|
||||
sent_starts (iterable): A sequence of sentence position tags, 1 for
|
||||
the first word in a sentence, 0 for all others.
|
||||
cats (dict): Labels for text classification. Each key in the dictionary
|
||||
may be a string or an int, or a `(start_char, end_char, label)`
|
||||
tuple, indicating that the label is applied to only part of the
|
||||
|
@ -956,6 +960,8 @@ cdef class GoldParse:
|
|||
deps = [None for _ in words]
|
||||
if not morphs:
|
||||
morphs = [None for _ in words]
|
||||
if not sent_starts:
|
||||
sent_starts = [None for _ in words]
|
||||
if entities is None:
|
||||
entities = ["-" for _ in words]
|
||||
elif len(entities) == 0:
|
||||
|
@ -982,6 +988,7 @@ cdef class GoldParse:
|
|||
self.labels = [None] * len(doc)
|
||||
self.ner = [None] * len(doc)
|
||||
self.morphs = [None] * len(doc)
|
||||
self.sent_starts = [None] * len(doc)
|
||||
|
||||
# This needs to be done before we align the words
|
||||
if make_projective and heads is not None and deps is not None:
|
||||
|
@ -1000,7 +1007,7 @@ cdef class GoldParse:
|
|||
self.gold_to_cand = [(i if i >= 0 else None) for i in j2i]
|
||||
|
||||
self.orig = TokenAnnotation(ids=list(range(len(words))), words=words, tags=tags,
|
||||
heads=heads, deps=deps, entities=entities, morphs=morphs,
|
||||
heads=heads, deps=deps, entities=entities, morphs=morphs, sent_starts=sent_starts,
|
||||
brackets=[])
|
||||
|
||||
for i, gold_i in enumerate(self.cand_to_gold):
|
||||
|
@ -1011,11 +1018,13 @@ cdef class GoldParse:
|
|||
self.labels[i] = None
|
||||
self.ner[i] = None
|
||||
self.morphs[i] = set()
|
||||
self.sent_starts[i] = 0
|
||||
if gold_i is None:
|
||||
if i in i2j_multi:
|
||||
self.words[i] = words[i2j_multi[i]]
|
||||
self.tags[i] = tags[i2j_multi[i]]
|
||||
self.morphs[i] = morphs[i2j_multi[i]]
|
||||
self.sent_starts[i] = sent_starts[i2j_multi[i]]
|
||||
is_last = i2j_multi[i] != i2j_multi.get(i+1)
|
||||
is_first = i2j_multi[i] != i2j_multi.get(i-1)
|
||||
# Set next word in multi-token span as head, until last
|
||||
|
@ -1055,6 +1064,7 @@ cdef class GoldParse:
|
|||
self.words[i] = words[gold_i]
|
||||
self.tags[i] = tags[gold_i]
|
||||
self.morphs[i] = morphs[gold_i]
|
||||
self.sent_starts[i] = sent_starts[gold_i]
|
||||
if heads[gold_i] is None:
|
||||
self.heads[i] = None
|
||||
else:
|
||||
|
@ -1091,21 +1101,6 @@ cdef class GoldParse:
|
|||
"""
|
||||
return not nonproj.is_nonproj_tree(self.heads)
|
||||
|
||||
property sent_starts:
|
||||
def __get__(self):
|
||||
return [self.c.sent_start[i] for i in range(self.length)]
|
||||
|
||||
def __set__(self, sent_starts):
|
||||
for gold_i, is_sent_start in enumerate(sent_starts):
|
||||
i = self.gold_to_cand[gold_i]
|
||||
if i is not None:
|
||||
if is_sent_start in (1, True):
|
||||
self.c.sent_start[i] = 1
|
||||
elif is_sent_start in (-1, False):
|
||||
self.c.sent_start[i] = -1
|
||||
else:
|
||||
self.c.sent_start[i] = 0
|
||||
|
||||
|
||||
def docs_to_json(docs, id=0):
|
||||
"""Convert a list of Doc objects into the JSON-serializable format used by
|
||||
|
|
|
@ -3,6 +3,7 @@ from __future__ import unicode_literals
|
|||
|
||||
from .pipes import Tagger, DependencyParser, EntityRecognizer, EntityLinker
|
||||
from .pipes import TextCategorizer, Tensorizer, Pipe, Sentencizer
|
||||
from .pipes import SentenceRecognizer
|
||||
from .morphologizer import Morphologizer
|
||||
from .entityruler import EntityRuler
|
||||
from .hooks import SentenceSegmenter, SimilarityHook
|
||||
|
@ -20,6 +21,7 @@ __all__ = [
|
|||
"EntityRuler",
|
||||
"Sentencizer",
|
||||
"SentenceSegmenter",
|
||||
"SentenceRecognizer",
|
||||
"SimilarityHook",
|
||||
"merge_entities",
|
||||
"merge_noun_chunks",
|
||||
|
|
|
@ -705,6 +705,169 @@ class Tagger(Pipe):
|
|||
return self
|
||||
|
||||
|
||||
@component("sentrec", assigns=["token.is_sent_start"])
|
||||
class SentenceRecognizer(Tagger):
|
||||
"""Pipeline component for sentence segmentation.
|
||||
|
||||
DOCS: https://spacy.io/api/sentencerecognizer
|
||||
"""
|
||||
|
||||
def __init__(self, vocab, model=True, **cfg):
|
||||
self.vocab = vocab
|
||||
self.model = model
|
||||
self._rehearsal_model = None
|
||||
self.cfg = OrderedDict(sorted(cfg.items()))
|
||||
self.cfg.setdefault("cnn_maxout_pieces", 2)
|
||||
self.cfg.setdefault("subword_features", True)
|
||||
self.cfg.setdefault("token_vector_width", 12)
|
||||
self.cfg.setdefault("conv_depth", 1)
|
||||
self.cfg.setdefault("pretrained_vectors", None)
|
||||
|
||||
@property
|
||||
def labels(self):
|
||||
# labels are numbered by index internally, so this matches GoldParse
|
||||
# and Example where the sentence-initial tag is 1 and other positions
|
||||
# are 0
|
||||
return tuple(["I", "S"])
|
||||
|
||||
def set_annotations(self, docs, batch_tag_ids, **_):
|
||||
if isinstance(docs, Doc):
|
||||
docs = [docs]
|
||||
cdef Doc doc
|
||||
for i, doc in enumerate(docs):
|
||||
doc_tag_ids = batch_tag_ids[i]
|
||||
if hasattr(doc_tag_ids, "get"):
|
||||
doc_tag_ids = doc_tag_ids.get()
|
||||
for j, tag_id in enumerate(doc_tag_ids):
|
||||
# Don't clobber existing sentence boundaries
|
||||
if doc.c[j].sent_start == 0:
|
||||
if tag_id == 1:
|
||||
doc.c[j].sent_start = 1
|
||||
else:
|
||||
doc.c[j].sent_start = -1
|
||||
|
||||
def update(self, examples, drop=0., sgd=None, losses=None):
|
||||
self.require_model()
|
||||
examples = Example.to_example_objects(examples)
|
||||
if losses is not None and self.name not in losses:
|
||||
losses[self.name] = 0.
|
||||
|
||||
if not any(len(ex.doc) if ex.doc else 0 for ex in examples):
|
||||
# Handle cases where there are no tokens in any docs.
|
||||
return
|
||||
|
||||
tag_scores, bp_tag_scores = self.model.begin_update([ex.doc for ex in examples], drop=drop)
|
||||
loss, d_tag_scores = self.get_loss(examples, tag_scores)
|
||||
bp_tag_scores(d_tag_scores, sgd=sgd)
|
||||
|
||||
if losses is not None:
|
||||
losses[self.name] += loss
|
||||
|
||||
def get_loss(self, examples, scores):
|
||||
scores = self.model.ops.flatten(scores)
|
||||
tag_index = range(len(self.labels))
|
||||
cdef int idx = 0
|
||||
correct = numpy.zeros((scores.shape[0],), dtype="i")
|
||||
guesses = scores.argmax(axis=1)
|
||||
known_labels = numpy.ones((scores.shape[0], 1), dtype="f")
|
||||
for ex in examples:
|
||||
gold = ex.gold
|
||||
for sent_start in gold.sent_starts:
|
||||
if sent_start is None:
|
||||
correct[idx] = guesses[idx]
|
||||
elif sent_start in tag_index:
|
||||
correct[idx] = sent_start
|
||||
else:
|
||||
correct[idx] = 0
|
||||
known_labels[idx] = 0.
|
||||
idx += 1
|
||||
correct = self.model.ops.xp.array(correct, dtype="i")
|
||||
d_scores = scores - to_categorical(correct, nb_classes=scores.shape[1])
|
||||
d_scores *= self.model.ops.asarray(known_labels)
|
||||
loss = (d_scores**2).sum()
|
||||
docs = [ex.doc for ex in examples]
|
||||
d_scores = self.model.ops.unflatten(d_scores, [len(d) for d in docs])
|
||||
return float(loss), d_scores
|
||||
|
||||
def begin_training(self, get_examples=lambda: [], pipeline=None, sgd=None,
|
||||
**kwargs):
|
||||
cdef Vocab vocab = self.vocab
|
||||
if self.model is True:
|
||||
for hp in ["token_vector_width", "conv_depth"]:
|
||||
if hp in kwargs:
|
||||
self.cfg[hp] = kwargs[hp]
|
||||
self.model = self.Model(len(self.labels), **self.cfg)
|
||||
if sgd is None:
|
||||
sgd = self.create_optimizer()
|
||||
return sgd
|
||||
|
||||
@classmethod
|
||||
def Model(cls, n_tags, **cfg):
|
||||
return build_tagger_model(n_tags, **cfg)
|
||||
|
||||
def add_label(self, label, values=None):
|
||||
raise NotImplementedError
|
||||
|
||||
def use_params(self, params):
|
||||
with self.model.use_params(params):
|
||||
yield
|
||||
|
||||
def to_bytes(self, exclude=tuple(), **kwargs):
|
||||
serialize = OrderedDict()
|
||||
if self.model not in (None, True, False):
|
||||
serialize["model"] = self.model.to_bytes
|
||||
serialize["vocab"] = self.vocab.to_bytes
|
||||
serialize["cfg"] = lambda: srsly.json_dumps(self.cfg)
|
||||
exclude = util.get_serialization_exclude(serialize, exclude, kwargs)
|
||||
return util.to_bytes(serialize, exclude)
|
||||
|
||||
def from_bytes(self, bytes_data, exclude=tuple(), **kwargs):
|
||||
def load_model(b):
|
||||
if self.model is True:
|
||||
self.model = self.Model(len(self.labels), **self.cfg)
|
||||
try:
|
||||
self.model.from_bytes(b)
|
||||
except AttributeError:
|
||||
raise ValueError(Errors.E149)
|
||||
|
||||
deserialize = OrderedDict((
|
||||
("vocab", lambda b: self.vocab.from_bytes(b)),
|
||||
("cfg", lambda b: self.cfg.update(srsly.json_loads(b))),
|
||||
("model", lambda b: load_model(b)),
|
||||
))
|
||||
exclude = util.get_serialization_exclude(deserialize, exclude, kwargs)
|
||||
util.from_bytes(bytes_data, deserialize, exclude)
|
||||
return self
|
||||
|
||||
def to_disk(self, path, exclude=tuple(), **kwargs):
|
||||
serialize = OrderedDict((
|
||||
("vocab", lambda p: self.vocab.to_disk(p)),
|
||||
("model", lambda p: p.open("wb").write(self.model.to_bytes())),
|
||||
("cfg", lambda p: srsly.write_json(p, self.cfg))
|
||||
))
|
||||
exclude = util.get_serialization_exclude(serialize, exclude, kwargs)
|
||||
util.to_disk(path, serialize, exclude)
|
||||
|
||||
def from_disk(self, path, exclude=tuple(), **kwargs):
|
||||
def load_model(p):
|
||||
if self.model is True:
|
||||
self.model = self.Model(len(self.labels), **self.cfg)
|
||||
with p.open("rb") as file_:
|
||||
try:
|
||||
self.model.from_bytes(file_.read())
|
||||
except AttributeError:
|
||||
raise ValueError(Errors.E149)
|
||||
|
||||
deserialize = OrderedDict((
|
||||
("cfg", lambda p: self.cfg.update(_load_cfg(p))),
|
||||
("vocab", lambda p: self.vocab.from_disk(p)),
|
||||
("model", load_model),
|
||||
))
|
||||
exclude = util.get_serialization_exclude(deserialize, exclude, kwargs)
|
||||
util.from_disk(path, deserialize, exclude)
|
||||
return self
|
||||
|
||||
|
||||
@component("nn_labeller")
|
||||
class MultitaskObjective(Tagger):
|
||||
"""Experimental: Assist training of a parser or tagger, by training a
|
||||
|
@ -1589,4 +1752,4 @@ Language.factories["parser"] = lambda nlp, **cfg: DependencyParser.from_nlp(nlp,
|
|||
Language.factories["ner"] = lambda nlp, **cfg: EntityRecognizer.from_nlp(nlp, **cfg)
|
||||
|
||||
|
||||
__all__ = ["Tagger", "DependencyParser", "EntityRecognizer", "Tensorizer", "TextCategorizer", "EntityLinker", "Sentencizer"]
|
||||
__all__ = ["Tagger", "DependencyParser", "EntityRecognizer", "Tensorizer", "TextCategorizer", "EntityLinker", "Sentencizer", "SentenceRecognizer"]
|
||||
|
|
|
@ -84,6 +84,7 @@ class Scorer(object):
|
|||
self.labelled = PRFScore()
|
||||
self.labelled_per_dep = dict()
|
||||
self.tags = PRFScore()
|
||||
self.sent_starts = PRFScore()
|
||||
self.ner = PRFScore()
|
||||
self.ner_per_ents = dict()
|
||||
self.eval_punct = eval_punct
|
||||
|
@ -113,6 +114,27 @@ class Scorer(object):
|
|||
"""
|
||||
return self.tags.fscore * 100
|
||||
|
||||
@property
|
||||
def sent_p(self):
|
||||
"""RETURNS (float): F-score for identification of sentence starts.
|
||||
i.e. `Token.is_sent_start`).
|
||||
"""
|
||||
return self.sent_starts.precision * 100
|
||||
|
||||
@property
|
||||
def sent_r(self):
|
||||
"""RETURNS (float): F-score for identification of sentence starts.
|
||||
i.e. `Token.is_sent_start`).
|
||||
"""
|
||||
return self.sent_starts.recall * 100
|
||||
|
||||
@property
|
||||
def sent_f(self):
|
||||
"""RETURNS (float): F-score for identification of sentence starts.
|
||||
i.e. `Token.is_sent_start`).
|
||||
"""
|
||||
return self.sent_starts.fscore * 100
|
||||
|
||||
@property
|
||||
def token_acc(self):
|
||||
"""RETURNS (float): Tokenization accuracy."""
|
||||
|
@ -212,6 +234,9 @@ class Scorer(object):
|
|||
"ents_f": self.ents_f,
|
||||
"ents_per_type": self.ents_per_type,
|
||||
"tags_acc": self.tags_acc,
|
||||
"sent_p": self.sent_p,
|
||||
"sent_r": self.sent_r,
|
||||
"sent_f": self.sent_f,
|
||||
"token_acc": self.token_acc,
|
||||
"textcat_score": self.textcat_score,
|
||||
"textcats_per_cat": self.textcats_per_cat,
|
||||
|
@ -242,9 +267,12 @@ class Scorer(object):
|
|||
gold_deps = set()
|
||||
gold_deps_per_dep = {}
|
||||
gold_tags = set()
|
||||
gold_sent_starts = set()
|
||||
gold_ents = set(tags_to_entities(orig.entities))
|
||||
for id_, tag, head, dep in zip(orig.ids, orig.tags, orig.heads, orig.deps):
|
||||
for id_, tag, head, dep, sent_start in zip(orig.ids, orig.tags, orig.heads, orig.deps, orig.sent_starts):
|
||||
gold_tags.add((id_, tag))
|
||||
if sent_start:
|
||||
gold_sent_starts.add(id_)
|
||||
if dep not in (None, "") and dep.lower() not in punct_labels:
|
||||
gold_deps.add((id_, head, dep.lower()))
|
||||
if dep.lower() not in self.labelled_per_dep:
|
||||
|
@ -255,6 +283,7 @@ class Scorer(object):
|
|||
cand_deps = set()
|
||||
cand_deps_per_dep = {}
|
||||
cand_tags = set()
|
||||
cand_sent_starts = set()
|
||||
for token in doc:
|
||||
if token.orth_.isspace():
|
||||
continue
|
||||
|
@ -264,6 +293,8 @@ class Scorer(object):
|
|||
else:
|
||||
self.tokens.tp += 1
|
||||
cand_tags.add((gold_i, token.tag_))
|
||||
if token.is_sent_start:
|
||||
cand_sent_starts.add(gold_i)
|
||||
if token.dep_.lower() not in punct_labels and token.orth_.strip():
|
||||
gold_head = gold.cand_to_gold[token.head.i]
|
||||
# None is indistinct, so we can't just add it to the set
|
||||
|
@ -308,6 +339,7 @@ class Scorer(object):
|
|||
# Score for all ents
|
||||
self.ner.score_set(cand_ents, gold_ents)
|
||||
self.tags.score_set(cand_tags, gold_tags)
|
||||
self.sent_starts.score_set(cand_sent_starts, gold_sent_starts)
|
||||
self.labelled.score_set(cand_deps, gold_deps)
|
||||
for dep in self.labelled_per_dep:
|
||||
self.labelled_per_dep[dep].score_set(cand_deps_per_dep.get(dep, set()), gold_deps_per_dep.get(dep, set()))
|
||||
|
|
|
@ -3,7 +3,7 @@ from __future__ import unicode_literals
|
|||
|
||||
import pytest
|
||||
from spacy.pipeline import Tagger, DependencyParser, EntityRecognizer
|
||||
from spacy.pipeline import Tensorizer, TextCategorizer
|
||||
from spacy.pipeline import Tensorizer, TextCategorizer, SentenceRecognizer
|
||||
|
||||
from ..util import make_tempdir
|
||||
|
||||
|
@ -144,3 +144,10 @@ def test_serialize_pipe_exclude(en_vocab, Parser):
|
|||
parser.to_bytes(cfg=False, exclude=["vocab"])
|
||||
with pytest.raises(ValueError):
|
||||
get_new_parser().from_bytes(parser.to_bytes(exclude=["vocab"]), cfg=False)
|
||||
|
||||
|
||||
def test_serialize_sentencerecognizer(en_vocab):
|
||||
sr = SentenceRecognizer(en_vocab)
|
||||
sr_b = sr.to_bytes()
|
||||
sr_d = SentenceRecognizer(en_vocab).from_bytes(sr_b)
|
||||
assert sr.to_bytes() == sr_d.to_bytes()
|
||||
|
|
Loading…
Reference in New Issue