mirror of https://github.com/explosion/spaCy.git
add initialize method for entity_ruler
This commit is contained in:
parent
e3acad6264
commit
251b3eb4e5
|
@ -456,6 +456,8 @@ class Errors:
|
|||
"issue tracker: http://github.com/explosion/spaCy/issues")
|
||||
|
||||
# TODO: fix numbering after merging develop into master
|
||||
E900 = ("Patterns for component '{name}' not initialized. This can be fixed "
|
||||
"by calling 'add_patterns' or 'initialize'.")
|
||||
E092 = ("The sentence-per-line IOB/IOB2 file is not formatted correctly. "
|
||||
"Try checking whitespace and delimiters. See "
|
||||
"https://nightly.spacy.io/api/cli#convert")
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
from typing import Optional, Union, List, Dict, Tuple, Iterable, Any
|
||||
from typing import Optional, Union, List, Dict, Tuple, Iterable, Any, Callable
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
import srsly
|
||||
from spacy.training import Example
|
||||
|
||||
from ..language import Language
|
||||
from ..errors import Errors
|
||||
|
@ -133,6 +134,7 @@ class EntityRuler:
|
|||
|
||||
DOCS: https://nightly.spacy.io/api/entityruler#call
|
||||
"""
|
||||
self._require_patterns()
|
||||
matches = list(self.matcher(doc)) + list(self.phrase_matcher(doc))
|
||||
matches = set(
|
||||
[(m_id, start, end) for m_id, start, end in matches if start != end]
|
||||
|
@ -183,6 +185,27 @@ class EntityRuler:
|
|||
all_labels.add(l)
|
||||
return tuple(all_labels)
|
||||
|
||||
def initialize(
|
||||
self,
|
||||
get_examples: Callable[[], Iterable[Example]],
|
||||
*,
|
||||
nlp: Optional[Language] = None,
|
||||
patterns_path: Optional[Path] = None
|
||||
):
|
||||
"""Initialize the pipe for training.
|
||||
|
||||
get_examples (Callable[[], Iterable[Example]]): Function that
|
||||
returns a representative sample of gold-standard Example objects.
|
||||
nlp (Language): The current nlp object the component is part of.
|
||||
patterns_path: Path to serialized patterns.
|
||||
|
||||
DOCS (TODO): https://nightly.spacy.io/api/entityruler#initialize
|
||||
"""
|
||||
if patterns_path:
|
||||
patterns = srsly.read_jsonl(patterns_path)
|
||||
self.add_patterns(patterns)
|
||||
|
||||
|
||||
@property
|
||||
def ent_ids(self) -> Tuple[str, ...]:
|
||||
"""All entity ids present in the match patterns `id` properties
|
||||
|
@ -292,6 +315,11 @@ class EntityRuler:
|
|||
self.phrase_patterns = defaultdict(list)
|
||||
self._ent_ids = defaultdict(dict)
|
||||
|
||||
def _require_patterns(self) -> None:
|
||||
"""Raise an error if the component has no patterns."""
|
||||
if not self.patterns or list(self.patterns) == [""]:
|
||||
raise ValueError(Errors.E900.format(name=self.name))
|
||||
|
||||
def _split_label(self, label: str) -> Tuple[str, str]:
|
||||
"""Split Entity label into ent_label and ent_id if it contains self.ent_id_sep
|
||||
|
||||
|
|
|
@ -49,7 +49,7 @@ def init_nlp(config: Config, *, use_gpu: int = -1) -> "Language":
|
|||
nlp.resume_training(sgd=optimizer)
|
||||
with nlp.select_pipes(disable=[*frozen_components, *resume_components]):
|
||||
nlp.initialize(lambda: train_corpus(nlp), sgd=optimizer)
|
||||
logger.info("Initialized pipeline components")
|
||||
logger.info(f"Initialized pipeline components: {nlp.pipe_names}")
|
||||
return nlp
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue