add initialize method for entity_ruler

This commit is contained in:
svlandeg 2020-10-05 14:59:13 +02:00
parent e3acad6264
commit 251b3eb4e5
3 changed files with 32 additions and 2 deletions

View File

@ -456,6 +456,8 @@ class Errors:
"issue tracker: http://github.com/explosion/spaCy/issues")
# TODO: fix numbering after merging develop into master
E900 = ("Patterns for component '{name}' not initialized. This can be fixed "
"by calling 'add_patterns' or 'initialize'.")
E092 = ("The sentence-per-line IOB/IOB2 file is not formatted correctly. "
"Try checking whitespace and delimiters. See "
"https://nightly.spacy.io/api/cli#convert")

View File

@ -1,7 +1,8 @@
from typing import Optional, Union, List, Dict, Tuple, Iterable, Any
from typing import Optional, Union, List, Dict, Tuple, Iterable, Any, Callable
from collections import defaultdict
from pathlib import Path
import srsly
from spacy.training import Example
from ..language import Language
from ..errors import Errors
@ -133,6 +134,7 @@ class EntityRuler:
DOCS: https://nightly.spacy.io/api/entityruler#call
"""
self._require_patterns()
matches = list(self.matcher(doc)) + list(self.phrase_matcher(doc))
matches = set(
[(m_id, start, end) for m_id, start, end in matches if start != end]
@ -183,6 +185,27 @@ class EntityRuler:
all_labels.add(l)
return tuple(all_labels)
def initialize(
self,
get_examples: Callable[[], Iterable[Example]],
*,
nlp: Optional[Language] = None,
patterns_path: Optional[Path] = None
):
"""Initialize the pipe for training.
get_examples (Callable[[], Iterable[Example]]): Function that
returns a representative sample of gold-standard Example objects.
nlp (Language): The current nlp object the component is part of.
patterns_path: Path to serialized patterns.
DOCS (TODO): https://nightly.spacy.io/api/entityruler#initialize
"""
if patterns_path:
patterns = srsly.read_jsonl(patterns_path)
self.add_patterns(patterns)
@property
def ent_ids(self) -> Tuple[str, ...]:
"""All entity ids present in the match patterns `id` properties
@ -292,6 +315,11 @@ class EntityRuler:
self.phrase_patterns = defaultdict(list)
self._ent_ids = defaultdict(dict)
def _require_patterns(self) -> None:
"""Raise an error if the component has no patterns."""
if not self.patterns or list(self.patterns) == [""]:
raise ValueError(Errors.E900.format(name=self.name))
def _split_label(self, label: str) -> Tuple[str, str]:
"""Split Entity label into ent_label and ent_id if it contains self.ent_id_sep

View File

@ -49,7 +49,7 @@ def init_nlp(config: Config, *, use_gpu: int = -1) -> "Language":
nlp.resume_training(sgd=optimizer)
with nlp.select_pipes(disable=[*frozen_components, *resume_components]):
nlp.initialize(lambda: train_corpus(nlp), sgd=optimizer)
logger.info("Initialized pipeline components")
logger.info(f"Initialized pipeline components: {nlp.pipe_names}")
return nlp