set_annotations explanation

This commit is contained in:
svlandeg 2020-10-04 14:56:48 +02:00
parent 9f40d963fd
commit b0463fbf75
1 changed files with 45 additions and 3 deletions

View File

@ -613,7 +613,7 @@ get_candidates = model.attrs["get_candidates"]
#### Step 2: Implementing the pipeline component {#component-rel-pipe}
To use our new relation extraction model as part of a custom component, we
To use our new relation extraction model as part of a custom component, we
create a subclass of [`Pipe`](/api/pipe) that will hold the model:
```python
@ -635,15 +635,57 @@ def make_relation_extractor(nlp, name, model, labels):
return RelationExtractor(nlp.vocab, model, name, labels=labels)
```
The [`predict`](/api/pipe#predict ) function needs to be implemented for each subclass.
In our case, we can simply delegate to the internal model's
The [`predict`](/api/pipe#predict) function needs to be implemented for each
subclass. In our case, we can simply delegate to the internal model's
[predict](https://thinc.ai/docs/api-model#predict) function:
```python
def predict(self, docs: Iterable[Doc]) -> Floats2d:
scores = self.model.predict(docs)
return self.model.ops.asarray(scores)
```
The other method that needs to be implemented, is
[`set_annotations`](/api/pipe#set_annotations). It takes the predicted scores,
and modifies the given `Doc` object in place to hold the predictions. For our
relation extraction component, we'll store the data as a dictionary in a custom
extension attribute `doc._.rel`. As keys, we represent the candidate pair by the
start offsets of each entity, as this defines an entity uniquely within one
document.
To interpret the scores predicted by the REL model correctly, we need to
refer to the model's `get_candidates` function that originally defined which
pairs of entities would be run through the model, so that the scores can be
related to those exact entities:
> #### Example output
>
> ```python
> doc = nlp("Amsterdam is the capital of the Netherlands.")
> print(f"spans: {[(e.start, e.text, e.label_) for e in doc.ents]}")
> for value, rel_dict in doc._.rel.items():
> print(f"{value}: {rel_dict}")
> ```
> ```
> spans [(0, 'Amsterdam', 'LOC'), (6, 'Netherlands', 'LOC')]
> (0, 6): {'CAPITAL_OF': 0.89, 'LOCATED_IN': 0.75, 'UNRELATED': 0.002}
> (6, 0): {'CAPITAL_OF': 0.01, 'LOCATED_IN': 0.13, 'UNRELATED': 0.017}
> ```
```python
def set_annotations(self, docs: Iterable[Doc], rel_scores: Floats2d):
c = 0
get_candidates = self.model.attrs["get_candidates"]
for doc in docs:
for (e1, e2) in get_candidates(doc):
offset = (e1.start, e2.start)
if offset not in doc._.rel:
doc._.rel[offset] = {}
for j, label in enumerate(self.labels):
doc._.rel[offset][label] = rel_scores[c, j]
c += 1
```