mirror of https://github.com/explosion/spaCy.git
Merge pull request #5996 from svlandeg/feature/docs-trf-examples [ci skip]
custom transformer examples
This commit is contained in:
commit
b6ee284376
|
@ -179,7 +179,7 @@ interoperates with [PyTorch](https://pytorch.org) and the
|
||||||
giving you access to thousands of pretrained models for your pipelines. There
|
giving you access to thousands of pretrained models for your pipelines. There
|
||||||
are many [great guides](http://jalammar.github.io/illustrated-transformer/) to
|
are many [great guides](http://jalammar.github.io/illustrated-transformer/) to
|
||||||
transformer models, but for practical purposes, you can simply think of them as
|
transformer models, but for practical purposes, you can simply think of them as
|
||||||
a drop-in replacement that let you achieve **higher accuracy** in exchange for
|
drop-in replacements that let you achieve **higher accuracy** in exchange for
|
||||||
**higher training and runtime costs**.
|
**higher training and runtime costs**.
|
||||||
|
|
||||||
### Setup and installation {#transformers-installation}
|
### Setup and installation {#transformers-installation}
|
||||||
|
@ -225,7 +225,7 @@ transformers as subnetworks directly, you can also use them via the
|
||||||
|
|
||||||
![The processing pipeline with the transformer component](../images/pipeline_transformer.svg)
|
![The processing pipeline with the transformer component](../images/pipeline_transformer.svg)
|
||||||
|
|
||||||
The `Transformer` component sets the
|
By default, the `Transformer` component sets the
|
||||||
[`Doc._.trf_data`](/api/transformer#custom_attributes) extension attribute,
|
[`Doc._.trf_data`](/api/transformer#custom_attributes) extension attribute,
|
||||||
which lets you access the transformers outputs at runtime.
|
which lets you access the transformers outputs at runtime.
|
||||||
|
|
||||||
|
@ -249,8 +249,8 @@ for doc in nlp.pipe(["some text", "some other text"]):
|
||||||
tokvecs = doc._.trf_data.tensors[-1]
|
tokvecs = doc._.trf_data.tensors[-1]
|
||||||
```
|
```
|
||||||
|
|
||||||
You can also customize how the [`Transformer`](/api/transformer) component sets
|
You can customize how the [`Transformer`](/api/transformer) component sets
|
||||||
annotations onto the [`Doc`](/api/doc), by customizing the `annotation_setter`.
|
annotations onto the [`Doc`](/api/doc), by changing the `annotation_setter`.
|
||||||
This callback will be called with the raw input and output data for the whole
|
This callback will be called with the raw input and output data for the whole
|
||||||
batch, along with the batch of `Doc` objects, allowing you to implement whatever
|
batch, along with the batch of `Doc` objects, allowing you to implement whatever
|
||||||
you need. The annotation setter is called with a batch of [`Doc`](/api/doc)
|
you need. The annotation setter is called with a batch of [`Doc`](/api/doc)
|
||||||
|
@ -259,13 +259,15 @@ containing the transformers data for the batch.
|
||||||
|
|
||||||
```python
|
```python
|
||||||
def custom_annotation_setter(docs, trf_data):
|
def custom_annotation_setter(docs, trf_data):
|
||||||
# TODO:
|
doc_data = list(trf_data.doc_data)
|
||||||
...
|
for doc, data in zip(docs, doc_data):
|
||||||
|
doc._.custom_attr = data
|
||||||
|
|
||||||
nlp = spacy.load("en_core_trf_lg")
|
nlp = spacy.load("en_core_trf_lg")
|
||||||
nlp.get_pipe("transformer").annotation_setter = custom_annotation_setter
|
nlp.get_pipe("transformer").annotation_setter = custom_annotation_setter
|
||||||
doc = nlp("This is a text")
|
doc = nlp("This is a text")
|
||||||
print() # TODO:
|
assert isinstance(doc._.custom_attr, TransformerData)
|
||||||
|
print(doc._.custom_attr.tensors)
|
||||||
```
|
```
|
||||||
|
|
||||||
### Training usage {#transformers-training}
|
### Training usage {#transformers-training}
|
||||||
|
@ -299,7 +301,7 @@ component:
|
||||||
>
|
>
|
||||||
> ```python
|
> ```python
|
||||||
> from spacy_transformers import Transformer, TransformerModel
|
> from spacy_transformers import Transformer, TransformerModel
|
||||||
> from spacy_transformers.annotation_setters import null_annotation_setter
|
> from spacy_transformers.annotation_setters import configure_trfdata_setter
|
||||||
> from spacy_transformers.span_getters import get_doc_spans
|
> from spacy_transformers.span_getters import get_doc_spans
|
||||||
>
|
>
|
||||||
> trf = Transformer(
|
> trf = Transformer(
|
||||||
|
@ -309,7 +311,7 @@ component:
|
||||||
> get_spans=get_doc_spans,
|
> get_spans=get_doc_spans,
|
||||||
> tokenizer_config={"use_fast": True},
|
> tokenizer_config={"use_fast": True},
|
||||||
> ),
|
> ),
|
||||||
> annotation_setter=null_annotation_setter,
|
> annotation_setter=configure_trfdata_setter(),
|
||||||
> max_batch_items=4096,
|
> max_batch_items=4096,
|
||||||
> )
|
> )
|
||||||
> ```
|
> ```
|
||||||
|
@ -329,7 +331,7 @@ tokenizer_config = {"use_fast": true}
|
||||||
@span_getters = "doc_spans.v1"
|
@span_getters = "doc_spans.v1"
|
||||||
|
|
||||||
[components.transformer.annotation_setter]
|
[components.transformer.annotation_setter]
|
||||||
@annotation_setters = "spacy-transformers.null_annotation_setter.v1"
|
@annotation_setters = "spacy-transformers.trfdata_setter.v1"
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
|
@ -343,9 +345,9 @@ in a block starts with `@`, it's **resolved to a function** and all other
|
||||||
settings are passed to the function as arguments. In this case, `name`,
|
settings are passed to the function as arguments. In this case, `name`,
|
||||||
`tokenizer_config` and `get_spans`.
|
`tokenizer_config` and `get_spans`.
|
||||||
|
|
||||||
`get_spans` is a function that takes a batch of `Doc` object and returns lists
|
`get_spans` is a function that takes a batch of `Doc` objects and returns lists
|
||||||
of potentially overlapping `Span` objects to process by the transformer. Several
|
of potentially overlapping `Span` objects to process by the transformer. Several
|
||||||
[built-in functions](/api/transformer#span-getters) are available – for example,
|
[built-in functions](/api/transformer#span_getters) are available – for example,
|
||||||
to process the whole document or individual sentences. When the config is
|
to process the whole document or individual sentences. When the config is
|
||||||
resolved, the function is created and passed into the model as an argument.
|
resolved, the function is created and passed into the model as an argument.
|
||||||
|
|
||||||
|
@ -366,13 +368,17 @@ To change any of the settings, you can edit the `config.cfg` and re-run the
|
||||||
training. To change any of the functions, like the span getter, you can replace
|
training. To change any of the functions, like the span getter, you can replace
|
||||||
the name of the referenced function – e.g. `@span_getters = "sent_spans.v1"` to
|
the name of the referenced function – e.g. `@span_getters = "sent_spans.v1"` to
|
||||||
process sentences. You can also register your own functions using the
|
process sentences. You can also register your own functions using the
|
||||||
`span_getters` registry:
|
`span_getters` registry. For instance, the following custom function returns
|
||||||
|
`Span` objects following sentence boundaries, unless a sentence succeeds a
|
||||||
|
certain amount of tokens, in which case subsentences of at most `max_length`
|
||||||
|
tokens are returned.
|
||||||
|
|
||||||
> #### config.cfg
|
> #### config.cfg
|
||||||
>
|
>
|
||||||
> ```ini
|
> ```ini
|
||||||
> [components.transformer.model.get_spans]
|
> [components.transformer.model.get_spans]
|
||||||
> @span_getters = "custom_sent_spans"
|
> @span_getters = "custom_sent_spans"
|
||||||
|
> max_length = 25
|
||||||
> ```
|
> ```
|
||||||
|
|
||||||
```python
|
```python
|
||||||
|
@ -380,12 +386,23 @@ process sentences. You can also register your own functions using the
|
||||||
import spacy_transformers
|
import spacy_transformers
|
||||||
|
|
||||||
@spacy_transformers.registry.span_getters("custom_sent_spans")
|
@spacy_transformers.registry.span_getters("custom_sent_spans")
|
||||||
def configure_custom_sent_spans():
|
def configure_custom_sent_spans(max_length: int):
|
||||||
# TODO: write custom example
|
def get_custom_sent_spans(docs):
|
||||||
def get_sent_spans(docs):
|
spans = []
|
||||||
return [list(doc.sents) for doc in docs]
|
for doc in docs:
|
||||||
|
spans.append([])
|
||||||
|
for sent in doc.sents:
|
||||||
|
start = 0
|
||||||
|
end = max_length
|
||||||
|
while end <= len(sent):
|
||||||
|
spans[-1].append(sent[start:end])
|
||||||
|
start += max_length
|
||||||
|
end += max_length
|
||||||
|
if start < len(sent):
|
||||||
|
spans[-1].append(sent[start:len(sent)])
|
||||||
|
return spans
|
||||||
|
|
||||||
return get_sent_spans
|
return get_custom_sent_spans
|
||||||
```
|
```
|
||||||
|
|
||||||
To resolve the config during training, spaCy needs to know about your custom
|
To resolve the config during training, spaCy needs to know about your custom
|
||||||
|
@ -412,8 +429,8 @@ The same idea applies to task models that power the **downstream components**.
|
||||||
Most of spaCy's built-in model creation functions support a `tok2vec` argument,
|
Most of spaCy's built-in model creation functions support a `tok2vec` argument,
|
||||||
which should be a Thinc layer of type ~~Model[List[Doc], List[Floats2d]]~~. This
|
which should be a Thinc layer of type ~~Model[List[Doc], List[Floats2d]]~~. This
|
||||||
is where we'll plug in our transformer model, using the
|
is where we'll plug in our transformer model, using the
|
||||||
[Tok2VecListener](/api/architectures#Tok2VecListener) layer, which sneakily
|
[TransformerListener](/api/architectures#TransformerListener) layer, which
|
||||||
delegates to the `Transformer` pipeline component.
|
sneakily delegates to the `Transformer` pipeline component.
|
||||||
|
|
||||||
```ini
|
```ini
|
||||||
### config.cfg (excerpt) {highlight="12"}
|
### config.cfg (excerpt) {highlight="12"}
|
||||||
|
@ -428,18 +445,18 @@ maxout_pieces = 3
|
||||||
use_upper = false
|
use_upper = false
|
||||||
|
|
||||||
[nlp.pipeline.ner.model.tok2vec]
|
[nlp.pipeline.ner.model.tok2vec]
|
||||||
@architectures = "spacy-transformers.Tok2VecListener.v1"
|
@architectures = "spacy-transformers.TransformerListener.v1"
|
||||||
grad_factor = 1.0
|
grad_factor = 1.0
|
||||||
|
|
||||||
[nlp.pipeline.ner.model.tok2vec.pooling]
|
[nlp.pipeline.ner.model.tok2vec.pooling]
|
||||||
@layers = "reduce_mean.v1"
|
@layers = "reduce_mean.v1"
|
||||||
```
|
```
|
||||||
|
|
||||||
The [Tok2VecListener](/api/architectures#Tok2VecListener) layer expects a
|
The [TransformerListener](/api/architectures#TransformerListener) layer expects
|
||||||
[pooling layer](https://thinc.ai/docs/api-layers#reduction-ops) as the argument
|
a [pooling layer](https://thinc.ai/docs/api-layers#reduction-ops) as the
|
||||||
`pooling`, which needs to be of type ~~Model[Ragged, Floats2d]~~. This layer
|
argument `pooling`, which needs to be of type ~~Model[Ragged, Floats2d]~~. This
|
||||||
determines how the vector for each spaCy token will be computed from the zero or
|
layer determines how the vector for each spaCy token will be computed from the
|
||||||
more source rows the token is aligned against. Here we use the
|
zero or more source rows the token is aligned against. Here we use the
|
||||||
[`reduce_mean`](https://thinc.ai/docs/api-layers#reduce_mean) layer, which
|
[`reduce_mean`](https://thinc.ai/docs/api-layers#reduce_mean) layer, which
|
||||||
averages the wordpiece rows. We could instead use
|
averages the wordpiece rows. We could instead use
|
||||||
[`reduce_max`](https://thinc.ai/docs/api-layers#reduce_max), or a custom
|
[`reduce_max`](https://thinc.ai/docs/api-layers#reduce_max), or a custom
|
||||||
|
@ -535,8 +552,9 @@ vectors, but combines them via summation with a smaller table of learned
|
||||||
embeddings.
|
embeddings.
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from thinc.api import add, chain, remap_ids, Embed
|
from thinc.api import add, chain, remap_ids, Embed, FeatureExtractor
|
||||||
from spacy.ml.staticvectors import StaticVectors
|
from spacy.ml.staticvectors import StaticVectors
|
||||||
|
from spacy.util import registry
|
||||||
|
|
||||||
@registry.architectures("my_example.MyEmbedding.v1")
|
@registry.architectures("my_example.MyEmbedding.v1")
|
||||||
def MyCustomVectors(
|
def MyCustomVectors(
|
||||||
|
|
Loading…
Reference in New Issue