diff --git a/spacy/pipeline/entity_linker.py b/spacy/pipeline/entity_linker.py
index e25777a21..630057c3f 100644
--- a/spacy/pipeline/entity_linker.py
+++ b/spacy/pipeline/entity_linker.py
@@ -45,6 +45,7 @@ DEFAULT_NEL_MODEL = Config().from_str(default_model_config)["model"]
     default_config={
         "model": DEFAULT_NEL_MODEL,
         "labels_discard": [],
+        "n_sents": 0,
         "incl_prior": True,
         "incl_context": True,
         "entity_vector_length": 64,
@@ -62,6 +63,7 @@ def make_entity_linker(
     model: Model,
     *,
     labels_discard: Iterable[str],
+    n_sents: int,
     incl_prior: bool,
     incl_context: bool,
     entity_vector_length: int,
@@ -73,6 +75,7 @@ def make_entity_linker(
         representations. Given a batch of Doc objects, it should return a
         single array, with one row per item in the batch.
     labels_discard (Iterable[str]): NER labels that will automatically get a "NIL" prediction.
+    n_sents (int): The number of neighbouring sentences to take into account.
     incl_prior (bool): Whether or not to include prior probabilities from the KB in the model.
     incl_context (bool): Whether or not to include the local context in the model.
     entity_vector_length (int): Size of encoding vectors in the KB.
@@ -84,6 +87,7 @@ def make_entity_linker(
         model,
         name,
         labels_discard=labels_discard,
+        n_sents=n_sents,
         incl_prior=incl_prior,
         incl_context=incl_context,
         entity_vector_length=entity_vector_length,
@@ -106,6 +110,7 @@ class EntityLinker(TrainablePipe):
         name: str = "entity_linker",
         *,
         labels_discard: Iterable[str],
+        n_sents: int,
         incl_prior: bool,
         incl_context: bool,
         entity_vector_length: int,
@@ -118,6 +123,7 @@ class EntityLinker(TrainablePipe):
         name (str): The component instance name, used to add entries to the
             losses during training.
         labels_discard (Iterable[str]): NER labels that will automatically get a "NIL" prediction.
+        n_sents (int): The number of neighbouring sentences to take into account.
         incl_prior (bool): Whether or not to include prior probabilities from the KB in the model.
         incl_context (bool): Whether or not to include the local context in the model.
         entity_vector_length (int): Size of encoding vectors in the KB.
@@ -129,17 +135,14 @@ class EntityLinker(TrainablePipe):
         self.vocab = vocab
         self.model = model
         self.name = name
-        cfg = {
-            "labels_discard": list(labels_discard),
-            "incl_prior": incl_prior,
-            "incl_context": incl_context,
-            "entity_vector_length": entity_vector_length,
-        }
+        self.labels_discard = list(labels_discard)
+        self.n_sents = n_sents
+        self.incl_prior = incl_prior
+        self.incl_context = incl_context
         self.get_candidates = get_candidates
-        self.cfg = dict(cfg)
+        self.cfg = {}
         self.distance = CosineDistance(normalize=False)
         # how many neightbour sentences to take into account
-        self.n_sents = cfg.get("n_sents", 0)
         # create an empty KB by default. If you want to load a predefined one, specify it in 'initialize'.
         self.kb = empty_kb(entity_vector_length)(self.vocab)
 
@@ -150,7 +153,6 @@ class EntityLinker(TrainablePipe):
             raise ValueError(Errors.E885.format(arg_type=type(kb_loader)))
 
         self.kb = kb_loader(self.vocab)
-        self.cfg["entity_vector_length"] = self.kb.entity_vector_length
 
     def validate_kb(self) -> None:
         # Raise an error if the knowledge base is not initialized.
@@ -312,14 +314,13 @@ class EntityLinker(TrainablePipe):
                         sent_doc = doc[start_token:end_token].as_doc()
                         # currently, the context is the same for each entity in a sentence (should be refined)
                         xp = self.model.ops.xp
-                        if self.cfg.get("incl_context"):
+                        if self.incl_context:
                             sentence_encoding = self.model.predict([sent_doc])[0]
                             sentence_encoding_t = sentence_encoding.T
                             sentence_norm = xp.linalg.norm(sentence_encoding_t)
                         for ent in sent.ents:
                             entity_count += 1
-                            to_discard = self.cfg.get("labels_discard", [])
-                            if to_discard and ent.label_ in to_discard:
+                            if ent.label_ in self.labels_discard:
                                 # ignoring this entity - setting to NIL
                                 final_kb_ids.append(self.NIL)
                             else:
@@ -337,13 +338,13 @@ class EntityLinker(TrainablePipe):
                                 prior_probs = xp.asarray(
                                     [c.prior_prob for c in candidates]
                                 )
-                                if not self.cfg.get("incl_prior"):
+                                if not self.incl_prior:
                                     prior_probs = xp.asarray(
                                         [0.0 for _ in candidates]
                                     )
                                 scores = prior_probs
                                 # add in similarity from the context
-                                if self.cfg.get("incl_context"):
+                                if self.incl_context:
                                     entity_encodings = xp.asarray(
                                         [c.entity_vector for c in candidates]
                                     )
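For context on the new `n_sents` setting: it widens the sentence window that becomes the context document via `sent_doc = doc[start_token:end_token].as_doc()` in `predict` above. A minimal sketch of that kind of clipped window, assuming a symmetric window of `n_sents` sentences on each side (the exact computation is not part of this diff):

```python
from typing import List, Tuple
from spacy.tokens import Span


def sentence_window(sentences: List[Span], sent_index: int, n_sents: int) -> Tuple[int, int]:
    """Return (start_token, end_token) for the context window around one sentence."""
    # Clip the neighbouring-sentence window to the bounds of the document.
    start_sentence = max(0, sent_index - n_sents)
    end_sentence = min(len(sentences) - 1, sent_index + n_sents)
    # Token offsets delimiting the context span, usable as doc[start_token:end_token].
    return sentences[start_sentence].start, sentences[end_sentence].end
```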
diff --git a/spacy/tests/pipeline/test_entity_linker.py b/spacy/tests/pipeline/test_entity_linker.py
index 8ba2d0d3e..348298e06 100644
--- a/spacy/tests/pipeline/test_entity_linker.py
+++ b/spacy/tests/pipeline/test_entity_linker.py
@@ -250,6 +250,14 @@ def test_el_pipe_configuration(nlp):
     assert doc[2].ent_kb_id_ == "Q2"
 
 
+def test_nel_nsents(nlp):
+    """Test that n_sents can be set through the configuration"""
+    entity_linker = nlp.add_pipe("entity_linker", config={})
+    assert entity_linker.n_sents == 0
+    entity_linker = nlp.replace_pipe("entity_linker", "entity_linker", config={"n_sents": 2})
+    assert entity_linker.n_sents == 2
+
+
 def test_vocab_serialization(nlp):
     """Test that string information is retained across storage"""
     mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
diff --git a/spacy/tests/pipeline/test_pipe_methods.py b/spacy/tests/pipeline/test_pipe_methods.py
index 6a21ddfaa..9af8395a6 100644
--- a/spacy/tests/pipeline/test_pipe_methods.py
+++ b/spacy/tests/pipeline/test_pipe_methods.py
@@ -83,9 +83,9 @@ def test_replace_last_pipe(nlp):
 def test_replace_pipe_config(nlp):
     nlp.add_pipe("entity_linker")
     nlp.add_pipe("sentencizer")
-    assert nlp.get_pipe("entity_linker").cfg["incl_prior"] is True
+    assert nlp.get_pipe("entity_linker").incl_prior is True
     nlp.replace_pipe("entity_linker", "entity_linker", config={"incl_prior": False})
-    assert nlp.get_pipe("entity_linker").cfg["incl_prior"] is False
+    assert nlp.get_pipe("entity_linker").incl_prior is False
 
 
 @pytest.mark.parametrize("old_name,new_name", [("old_pipe", "new_pipe")])
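The new `test_nel_nsents` test above shows the intended flow: the setting travels from the factory config to the component attribute. A minimal usage sketch, assuming a blank pipeline (the component would still need a knowledge base via `initialize` before it can actually link entities):

```python
import spacy

nlp = spacy.blank("en")
# "n_sents" comes from the factory config and is exposed as a plain attribute
entity_linker = nlp.add_pipe("entity_linker", config={"n_sents": 2})
assert entity_linker.n_sents == 2
```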
diff --git a/website/docs/api/entitylinker.md b/website/docs/api/entitylinker.md
index a794ce632..1cc864059 100644
--- a/website/docs/api/entitylinker.md
+++ b/website/docs/api/entitylinker.md
@@ -31,6 +31,7 @@ architectures and their arguments and hyperparameters.
 > from spacy.pipeline.entity_linker import DEFAULT_NEL_MODEL
 > config = {
 >    "labels_discard": [],
+>    "n_sents": 0,
 >    "incl_prior": True,
 >    "incl_context": True,
 >    "model": DEFAULT_NEL_MODEL,
@@ -43,6 +44,7 @@ architectures and their arguments and hyperparameters.
 | Setting                | Description                                                                                                                                             |
 | ---------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------ |
 | `labels_discard`       | NER labels that will automatically get a "NIL" prediction. Defaults to `[]`. ~~Iterable[str]~~                                                          |
+| `n_sents`              | The number of neighbouring sentences to take into account. Defaults to 0. ~~int~~                                                                       |
 | `incl_prior`           | Whether or not to include prior probabilities from the KB in the model. Defaults to `True`. ~~bool~~                                                    |
 | `incl_context`         | Whether or not to include the local context in the model. Defaults to `True`. ~~bool~~                                                                  |
 | `model`                | The [`Model`](https://thinc.ai/docs/api-model) powering the pipeline component. Defaults to [EntityLinker](/api/architectures#EntityLinker). ~~Model~~  |
@@ -89,6 +91,7 @@ custom knowledge base, you should either call
 | `entity_vector_length` | Size of encoding vectors in the KB. ~~int~~                                                                                       |
 | `get_candidates`       | Function that generates plausible candidates for a given `Span` object. ~~Callable[[KnowledgeBase, Span], Iterable[Candidate]]~~ |
 | `labels_discard`       | NER labels that will automatically get a `"NIL"` prediction. ~~Iterable[str]~~                                                   |
+| `n_sents`              | The number of neighbouring sentences to take into account. ~~int~~                                                               |
 | `incl_prior`           | Whether or not to include prior probabilities from the KB in the model. ~~bool~~                                                 |
 | `incl_context`         | Whether or not to include the local context in the model. ~~bool~~                                                               |
@@ -247,14 +250,14 @@ pipe's entity linking model and context encoder. Delegates to
 > losses = entity_linker.update(examples, sgd=optimizer)
 > ```
 
-| Name              | Description                                                                                                                          |
-| ----------------- | ------------------------------------------------------------------------------------------------------------------------------------ |
-| `examples`        | A batch of [`Example`](/api/example) objects to learn from. ~~Iterable[Example]~~                                                    |
-| _keyword-only_    |                                                                                                                                      |
-| `drop`            | The dropout rate. ~~float~~                                                                                                          |
-| `sgd`             | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if not set. ~~Optional[Optimizer]~~                        |
-| `losses`          | Optional record of the loss during training. Updated using the component name as the key. ~~Optional[Dict[str, float]]~~            |
-| **RETURNS**       | The updated `losses` dictionary. ~~Dict[str, float]~~                                                                                |
+| Name           | Description                                                                                                                |
+| -------------- | -------------------------------------------------------------------------------------------------------------------------- |
+| `examples`     | A batch of [`Example`](/api/example) objects to learn from. ~~Iterable[Example]~~                                          |
+| _keyword-only_ |                                                                                                                            |
+| `drop`         | The dropout rate. ~~float~~                                                                                                |
+| `sgd`          | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if not set. ~~Optional[Optimizer]~~              |
+| `losses`       | Optional record of the loss during training. Updated using the component name as the key. ~~Optional[Dict[str, float]]~~   |
+| **RETURNS**    | The updated `losses` dictionary. ~~Dict[str, float]~~                                                                      |
 
 ## EntityLinker.score {#score tag="method" new="3"}
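One consequence of moving the settings out of `self.cfg` and onto the component itself: callers should read the attributes rather than `cfg` entries. A short sketch mirroring the updated `test_replace_pipe_config`, assuming `nlp` already carries an `entity_linker` pipe:

```python
# settings are now plain attributes instead of entries in the cfg dict
linker = nlp.get_pipe("entity_linker")
assert linker.incl_prior is True
nlp.replace_pipe("entity_linker", "entity_linker", config={"incl_prior": False})
assert nlp.get_pipe("entity_linker").incl_prior is False
```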