mirror of https://github.com/explosion/spaCy.git
Add spacy.TextCatParametricAttention.v1 (#13201)
* Add spacy.TextCatParametricAttention.v1 This layer provides is a simplification of the ensemble classifier that only uses paramteric attention. We have found empirically that with a sufficient amount of training data, using the ensemble classifier with BoW does not provide significant improvement in classifier accuracy. However, plugging in a BoW classifier does reduce GPU training and inference performance substantially, since it uses a GPU-only kernel. * Fix merge fallout
This commit is contained in:
parent
7ebba86402
commit
e2a3952de5
|
@ -5,7 +5,7 @@ requires = [
|
||||||
"cymem>=2.0.2,<2.1.0",
|
"cymem>=2.0.2,<2.1.0",
|
||||||
"preshed>=3.0.2,<3.1.0",
|
"preshed>=3.0.2,<3.1.0",
|
||||||
"murmurhash>=0.28.0,<1.1.0",
|
"murmurhash>=0.28.0,<1.1.0",
|
||||||
"thinc>=8.1.8,<8.3.0",
|
"thinc>=8.2.2,<8.3.0",
|
||||||
"numpy>=1.15.0; python_version < '3.9'",
|
"numpy>=1.15.0; python_version < '3.9'",
|
||||||
"numpy>=1.25.0; python_version >= '3.9'",
|
"numpy>=1.25.0; python_version >= '3.9'",
|
||||||
]
|
]
|
||||||
|
|
|
@ -3,7 +3,7 @@ spacy-legacy>=3.0.11,<3.1.0
|
||||||
spacy-loggers>=1.0.0,<2.0.0
|
spacy-loggers>=1.0.0,<2.0.0
|
||||||
cymem>=2.0.2,<2.1.0
|
cymem>=2.0.2,<2.1.0
|
||||||
preshed>=3.0.2,<3.1.0
|
preshed>=3.0.2,<3.1.0
|
||||||
thinc>=8.1.8,<8.3.0
|
thinc>=8.2.2,<8.3.0
|
||||||
ml_datasets>=0.2.0,<0.3.0
|
ml_datasets>=0.2.0,<0.3.0
|
||||||
murmurhash>=0.28.0,<1.1.0
|
murmurhash>=0.28.0,<1.1.0
|
||||||
wasabi>=0.9.1,<1.2.0
|
wasabi>=0.9.1,<1.2.0
|
||||||
|
|
|
@ -41,7 +41,7 @@ setup_requires =
|
||||||
cymem>=2.0.2,<2.1.0
|
cymem>=2.0.2,<2.1.0
|
||||||
preshed>=3.0.2,<3.1.0
|
preshed>=3.0.2,<3.1.0
|
||||||
murmurhash>=0.28.0,<1.1.0
|
murmurhash>=0.28.0,<1.1.0
|
||||||
thinc>=8.1.8,<8.3.0
|
thinc>=8.2.2,<8.3.0
|
||||||
install_requires =
|
install_requires =
|
||||||
# Our libraries
|
# Our libraries
|
||||||
spacy-legacy>=3.0.11,<3.1.0
|
spacy-legacy>=3.0.11,<3.1.0
|
||||||
|
@ -49,7 +49,7 @@ install_requires =
|
||||||
murmurhash>=0.28.0,<1.1.0
|
murmurhash>=0.28.0,<1.1.0
|
||||||
cymem>=2.0.2,<2.1.0
|
cymem>=2.0.2,<2.1.0
|
||||||
preshed>=3.0.2,<3.1.0
|
preshed>=3.0.2,<3.1.0
|
||||||
thinc>=8.1.8,<8.3.0
|
thinc>=8.2.2,<8.3.0
|
||||||
wasabi>=0.9.1,<1.2.0
|
wasabi>=0.9.1,<1.2.0
|
||||||
srsly>=2.4.3,<3.0.0
|
srsly>=2.4.3,<3.0.0
|
||||||
catalogue>=2.0.6,<2.1.0
|
catalogue>=2.0.6,<2.1.0
|
||||||
|
|
|
@ -3,12 +3,14 @@ from typing import List, Optional, Tuple, cast
|
||||||
|
|
||||||
from thinc.api import (
|
from thinc.api import (
|
||||||
Dropout,
|
Dropout,
|
||||||
|
Gelu,
|
||||||
LayerNorm,
|
LayerNorm,
|
||||||
Linear,
|
Linear,
|
||||||
Logistic,
|
Logistic,
|
||||||
Maxout,
|
Maxout,
|
||||||
Model,
|
Model,
|
||||||
ParametricAttention,
|
ParametricAttention,
|
||||||
|
ParametricAttention_v2,
|
||||||
Relu,
|
Relu,
|
||||||
Softmax,
|
Softmax,
|
||||||
SparseLinear,
|
SparseLinear,
|
||||||
|
@ -146,6 +148,9 @@ def build_text_classifier_v2(
|
||||||
linear_model: Model[List[Doc], Floats2d],
|
linear_model: Model[List[Doc], Floats2d],
|
||||||
nO: Optional[int] = None,
|
nO: Optional[int] = None,
|
||||||
) -> Model[List[Doc], Floats2d]:
|
) -> Model[List[Doc], Floats2d]:
|
||||||
|
# TODO: build the model with _build_parametric_attention_with_residual_nonlinear
|
||||||
|
# in spaCy v4. We don't do this in spaCy v3 to preserve model
|
||||||
|
# compatibility.
|
||||||
exclusive_classes = not linear_model.attrs["multi_label"]
|
exclusive_classes = not linear_model.attrs["multi_label"]
|
||||||
with Model.define_operators({">>": chain, "|": concatenate}):
|
with Model.define_operators({">>": chain, "|": concatenate}):
|
||||||
width = tok2vec.maybe_get_dim("nO")
|
width = tok2vec.maybe_get_dim("nO")
|
||||||
|
@ -211,6 +216,71 @@ def build_text_classifier_lowdata(
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
@registry.architectures("spacy.TextCatParametricAttention.v1")
|
||||||
|
def build_textcat_parametric_attention_v1(
|
||||||
|
tok2vec: Model[List[Doc], List[Floats2d]],
|
||||||
|
exclusive_classes: bool,
|
||||||
|
nO: Optional[int] = None,
|
||||||
|
) -> Model[List[Doc], Floats2d]:
|
||||||
|
width = tok2vec.maybe_get_dim("nO")
|
||||||
|
parametric_attention = _build_parametric_attention_with_residual_nonlinear(
|
||||||
|
tok2vec=tok2vec,
|
||||||
|
nonlinear_layer=Maxout(nI=width, nO=width),
|
||||||
|
key_transform=Gelu(nI=width, nO=width),
|
||||||
|
)
|
||||||
|
with Model.define_operators({">>": chain}):
|
||||||
|
if exclusive_classes:
|
||||||
|
output_layer = Softmax(nO=nO)
|
||||||
|
else:
|
||||||
|
output_layer = Linear(nO=nO) >> Logistic()
|
||||||
|
model = parametric_attention >> output_layer
|
||||||
|
if model.has_dim("nO") is not False and nO is not None:
|
||||||
|
model.set_dim("nO", cast(int, nO))
|
||||||
|
model.set_ref("output_layer", output_layer)
|
||||||
|
model.attrs["multi_label"] = not exclusive_classes
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def _build_parametric_attention_with_residual_nonlinear(
|
||||||
|
*,
|
||||||
|
tok2vec: Model[List[Doc], List[Floats2d]],
|
||||||
|
nonlinear_layer: Model[Floats2d, Floats2d],
|
||||||
|
key_transform: Optional[Model[Floats2d, Floats2d]] = None,
|
||||||
|
) -> Model[List[Doc], Floats2d]:
|
||||||
|
with Model.define_operators({">>": chain, "|": concatenate}):
|
||||||
|
width = tok2vec.maybe_get_dim("nO")
|
||||||
|
attention_layer = ParametricAttention_v2(nO=width, key_transform=key_transform)
|
||||||
|
norm_layer = LayerNorm(nI=width)
|
||||||
|
parametric_attention = (
|
||||||
|
tok2vec
|
||||||
|
>> list2ragged()
|
||||||
|
>> attention_layer
|
||||||
|
>> reduce_sum()
|
||||||
|
>> residual(nonlinear_layer >> norm_layer >> Dropout(0.0))
|
||||||
|
)
|
||||||
|
|
||||||
|
parametric_attention.init = _init_parametric_attention_with_residual_nonlinear
|
||||||
|
|
||||||
|
parametric_attention.set_ref("tok2vec", tok2vec)
|
||||||
|
parametric_attention.set_ref("attention_layer", attention_layer)
|
||||||
|
parametric_attention.set_ref("nonlinear_layer", nonlinear_layer)
|
||||||
|
parametric_attention.set_ref("norm_layer", norm_layer)
|
||||||
|
|
||||||
|
return parametric_attention
|
||||||
|
|
||||||
|
|
||||||
|
def _init_parametric_attention_with_residual_nonlinear(model, X, Y) -> Model:
|
||||||
|
tok2vec_width = get_tok2vec_width(model)
|
||||||
|
model.get_ref("attention_layer").set_dim("nO", tok2vec_width)
|
||||||
|
model.get_ref("nonlinear_layer").set_dim("nO", tok2vec_width)
|
||||||
|
model.get_ref("nonlinear_layer").set_dim("nI", tok2vec_width)
|
||||||
|
model.get_ref("norm_layer").set_dim("nI", tok2vec_width)
|
||||||
|
model.get_ref("norm_layer").set_dim("nO", tok2vec_width)
|
||||||
|
init_chain(model, X, Y)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
@registry.architectures("spacy.TextCatReduce.v1")
|
@registry.architectures("spacy.TextCatReduce.v1")
|
||||||
def build_reduce_text_classifier(
|
def build_reduce_text_classifier(
|
||||||
tok2vec: Model,
|
tok2vec: Model,
|
||||||
|
|
|
@ -704,6 +704,9 @@ def test_overfitting_IO_multi():
|
||||||
# CNN V2 (legacy)
|
# CNN V2 (legacy)
|
||||||
("textcat", TRAIN_DATA_SINGLE_LABEL, {"@architectures": "spacy.TextCatCNN.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": True}),
|
("textcat", TRAIN_DATA_SINGLE_LABEL, {"@architectures": "spacy.TextCatCNN.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": True}),
|
||||||
("textcat_multilabel", TRAIN_DATA_MULTI_LABEL, {"@architectures": "spacy.TextCatCNN.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": False}),
|
("textcat_multilabel", TRAIN_DATA_MULTI_LABEL, {"@architectures": "spacy.TextCatCNN.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": False}),
|
||||||
|
# PARAMETRIC ATTENTION V1
|
||||||
|
("textcat", TRAIN_DATA_SINGLE_LABEL, {"@architectures": "spacy.TextCatParametricAttention.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": True}),
|
||||||
|
("textcat_multilabel", TRAIN_DATA_MULTI_LABEL, {"@architectures": "spacy.TextCatParametricAttention.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": False}),
|
||||||
# REDUCE V1
|
# REDUCE V1
|
||||||
("textcat", TRAIN_DATA_SINGLE_LABEL, {"@architectures": "spacy.TextCatReduce.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": True, "use_reduce_first": True, "use_reduce_last": True, "use_reduce_max": True, "use_reduce_mean": True}),
|
("textcat", TRAIN_DATA_SINGLE_LABEL, {"@architectures": "spacy.TextCatReduce.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": True, "use_reduce_first": True, "use_reduce_last": True, "use_reduce_max": True, "use_reduce_mean": True}),
|
||||||
("textcat_multilabel", TRAIN_DATA_MULTI_LABEL, {"@architectures": "spacy.TextCatReduce.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": False, "use_reduce_first": True, "use_reduce_last": True, "use_reduce_max": True, "use_reduce_mean": True}),
|
("textcat_multilabel", TRAIN_DATA_MULTI_LABEL, {"@architectures": "spacy.TextCatReduce.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": False, "use_reduce_first": True, "use_reduce_last": True, "use_reduce_max": True, "use_reduce_mean": True}),
|
||||||
|
|
|
@ -1056,6 +1056,44 @@ the others, but may not be as accurate, especially if texts are short.
|
||||||
|
|
||||||
</Accordion>
|
</Accordion>
|
||||||
|
|
||||||
|
### spacy.TextCatParametricAttention.v1 {id="TextCatParametricAttention"}
|
||||||
|
|
||||||
|
> #### Example Config
|
||||||
|
>
|
||||||
|
> ```ini
|
||||||
|
> [model]
|
||||||
|
> @architectures = "spacy.TextCatParametricAttention.v1"
|
||||||
|
> exclusive_classes = true
|
||||||
|
> nO = null
|
||||||
|
>
|
||||||
|
> [model.tok2vec]
|
||||||
|
> @architectures = "spacy.Tok2Vec.v2"
|
||||||
|
>
|
||||||
|
> [model.tok2vec.embed]
|
||||||
|
> @architectures = "spacy.MultiHashEmbed.v2"
|
||||||
|
> width = 64
|
||||||
|
> rows = [2000, 2000, 1000, 1000, 1000, 1000]
|
||||||
|
> attrs = ["ORTH", "LOWER", "PREFIX", "SUFFIX", "SHAPE", "ID"]
|
||||||
|
> include_static_vectors = false
|
||||||
|
>
|
||||||
|
> [model.tok2vec.encode]
|
||||||
|
> @architectures = "spacy.MaxoutWindowEncoder.v2"
|
||||||
|
> width = ${model.tok2vec.embed.width}
|
||||||
|
> window_size = 1
|
||||||
|
> maxout_pieces = 3
|
||||||
|
> depth = 2
|
||||||
|
> ```
|
||||||
|
|
||||||
|
A neural network model that is built upon Tok2Vec and uses parametric attention
|
||||||
|
to attend to tokens that are relevant to text classification.
|
||||||
|
|
||||||
|
| Name | Description |
|
||||||
|
| ------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||||
|
| `tok2vec` | The `tok2vec` layer to build the neural network upon. ~~Model[List[Doc], List[Floats2d]]~~ |
|
||||||
|
| `exclusive_classes` | Whether or not categories are mutually exclusive. ~~bool~~ |
|
||||||
|
| `nO` | Output dimension, determined by the number of different labels. If not set, the [`TextCategorizer`](/api/textcategorizer) component will set it when `initialize` is called. ~~Optional[int]~~ |
|
||||||
|
| **CREATES** | The model using the architecture. ~~Model[List[Doc], Floats2d]~~ |
|
||||||
|
|
||||||
### spacy.TextCatReduce.v1 {id="TextCatReduce"}
|
### spacy.TextCatReduce.v1 {id="TextCatReduce"}
|
||||||
|
|
||||||
> #### Example Config
|
> #### Example Config
|
||||||
|
|
Loading…
Reference in New Issue