diff --git a/website/docs/usage/layers-architectures.md b/website/docs/usage/layers-architectures.md index c4b3fb9dc..7e563cb5c 100644 --- a/website/docs/usage/layers-architectures.md +++ b/website/docs/usage/layers-architectures.md @@ -613,7 +613,7 @@ get_candidates = model.attrs["get_candidates"] #### Step 2: Implementing the pipeline component {#component-rel-pipe} -To use our new relation extraction model as part of a custom component, we +To use our new relation extraction model as part of a custom component, we create a subclass of [`Pipe`](/api/pipe) that will hold the model: ```python @@ -635,15 +635,57 @@ def make_relation_extractor(nlp, name, model, labels): return RelationExtractor(nlp.vocab, model, name, labels=labels) ``` -The [`predict`](/api/pipe#predict ) function needs to be implemented for each subclass. -In our case, we can simply delegate to the internal model's +The [`predict`](/api/pipe#predict) function needs to be implemented for each +subclass. In our case, we can simply delegate to the internal model's [predict](https://thinc.ai/docs/api-model#predict) function: + ```python def predict(self, docs: Iterable[Doc]) -> Floats2d: scores = self.model.predict(docs) return self.model.ops.asarray(scores) ``` +The other method that needs to be implemented, is +[`set_annotations`](/api/pipe#set_annotations). It takes the predicted scores, +and modifies the given `Doc` object in place to hold the predictions. For our +relation extraction component, we'll store the data as a dictionary in a custom +extension attribute `doc._.rel`. As keys, we represent the candidate pair by the +start offsets of each entity, as this defines an entity uniquely within one +document. + +To interpret the scores predicted by the REL model correctly, we need to +refer to the model's `get_candidates` function that originally defined which +pairs of entities would be run through the model, so that the scores can be +related to those exact entities: + +> #### Example output +> +> ```python +> doc = nlp("Amsterdam is the capital of the Netherlands.") +> print(f"spans: {[(e.start, e.text, e.label_) for e in doc.ents]}") +> for value, rel_dict in doc._.rel.items(): +> print(f"{value}: {rel_dict}") +> ``` + +> ``` +> spans [(0, 'Amsterdam', 'LOC'), (6, 'Netherlands', 'LOC')] +> (0, 6): {'CAPITAL_OF': 0.89, 'LOCATED_IN': 0.75, 'UNRELATED': 0.002} +> (6, 0): {'CAPITAL_OF': 0.01, 'LOCATED_IN': 0.13, 'UNRELATED': 0.017} +> ``` + +```python +def set_annotations(self, docs: Iterable[Doc], rel_scores: Floats2d): + c = 0 + get_candidates = self.model.attrs["get_candidates"] + for doc in docs: + for (e1, e2) in get_candidates(doc): + offset = (e1.start, e2.start) + if offset not in doc._.rel: + doc._.rel[offset] = {} + for j, label in enumerate(self.labels): + doc._.rel[offset][label] = rel_scores[c, j] + c += 1 +```