mirror of https://github.com/explosion/spaCy.git
Support registered vectors (#12492)
* Support registered vectors * Format * Auto-fill [nlp] on load from config and from bytes/disk * Only auto-fill [nlp] * Undo all changes to Language.from_disk * Expand BaseVectors These methods are needed in various places for training and vector similarity. * isort * More linting * Only fill [nlp.vectors] * Update spacy/vocab.pyx Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com> * Revert changes to test related to auto-filling [nlp] * Add vectors registry * Rephrase error about vocab methods for vectors * Switch to dummy implementation for BaseVectors.to_ops * Add initial draft of docs * Remove example from BaseVectors docs * Apply suggestions from code review Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com> * Update website/docs/api/basevectors.mdx Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com> * Fix type and lint bpemb example * Update website/docs/api/basevectors.mdx --------- Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>
This commit is contained in:
parent
9ffa5d8a15
commit
0fe43f40f1
|
@ -26,6 +26,9 @@ batch_size = 1000
|
||||||
[nlp.tokenizer]
|
[nlp.tokenizer]
|
||||||
@tokenizers = "spacy.Tokenizer.v1"
|
@tokenizers = "spacy.Tokenizer.v1"
|
||||||
|
|
||||||
|
[nlp.vectors]
|
||||||
|
@vectors = "spacy.Vectors.v1"
|
||||||
|
|
||||||
# The pipeline components and their models
|
# The pipeline components and their models
|
||||||
[components]
|
[components]
|
||||||
|
|
||||||
|
|
|
@ -553,6 +553,8 @@ class Errors(metaclass=ErrorsWithCodes):
|
||||||
"during training, make sure to include it in 'annotating components'")
|
"during training, make sure to include it in 'annotating components'")
|
||||||
|
|
||||||
# New errors added in v3.x
|
# New errors added in v3.x
|
||||||
|
E849 = ("The vocab only supports {method} for vectors of type "
|
||||||
|
"spacy.vectors.Vectors, not {vectors_type}.")
|
||||||
E850 = ("The PretrainVectors objective currently only supports default or "
|
E850 = ("The PretrainVectors objective currently only supports default or "
|
||||||
"floret vectors, not {mode} vectors.")
|
"floret vectors, not {mode} vectors.")
|
||||||
E851 = ("The 'textcat' component labels should only have values of 0 or 1, "
|
E851 = ("The 'textcat' component labels should only have values of 0 or 1, "
|
||||||
|
|
|
@ -65,6 +65,7 @@ from .util import (
|
||||||
registry,
|
registry,
|
||||||
warn_if_jupyter_cupy,
|
warn_if_jupyter_cupy,
|
||||||
)
|
)
|
||||||
|
from .vectors import BaseVectors
|
||||||
from .vocab import Vocab, create_vocab
|
from .vocab import Vocab, create_vocab
|
||||||
|
|
||||||
PipeCallable = Callable[[Doc], Doc]
|
PipeCallable = Callable[[Doc], Doc]
|
||||||
|
@ -158,6 +159,7 @@ class Language:
|
||||||
max_length: int = 10**6,
|
max_length: int = 10**6,
|
||||||
meta: Dict[str, Any] = {},
|
meta: Dict[str, Any] = {},
|
||||||
create_tokenizer: Optional[Callable[["Language"], Callable[[str], Doc]]] = None,
|
create_tokenizer: Optional[Callable[["Language"], Callable[[str], Doc]]] = None,
|
||||||
|
create_vectors: Optional[Callable[["Vocab"], BaseVectors]] = None,
|
||||||
batch_size: int = 1000,
|
batch_size: int = 1000,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
@ -198,6 +200,10 @@ class Language:
|
||||||
if vocab is True:
|
if vocab is True:
|
||||||
vectors_name = meta.get("vectors", {}).get("name")
|
vectors_name = meta.get("vectors", {}).get("name")
|
||||||
vocab = create_vocab(self.lang, self.Defaults, vectors_name=vectors_name)
|
vocab = create_vocab(self.lang, self.Defaults, vectors_name=vectors_name)
|
||||||
|
if not create_vectors:
|
||||||
|
vectors_cfg = {"vectors": self._config["nlp"]["vectors"]}
|
||||||
|
create_vectors = registry.resolve(vectors_cfg)["vectors"]
|
||||||
|
vocab.vectors = create_vectors(vocab)
|
||||||
else:
|
else:
|
||||||
if (self.lang and vocab.lang) and (self.lang != vocab.lang):
|
if (self.lang and vocab.lang) and (self.lang != vocab.lang):
|
||||||
raise ValueError(Errors.E150.format(nlp=self.lang, vocab=vocab.lang))
|
raise ValueError(Errors.E150.format(nlp=self.lang, vocab=vocab.lang))
|
||||||
|
@ -1765,6 +1771,10 @@ class Language:
|
||||||
).merge(config)
|
).merge(config)
|
||||||
if "nlp" not in config:
|
if "nlp" not in config:
|
||||||
raise ValueError(Errors.E985.format(config=config))
|
raise ValueError(Errors.E985.format(config=config))
|
||||||
|
# fill in [nlp.vectors] if not present (as a narrower alternative to
|
||||||
|
# auto-filling [nlp] from the default config)
|
||||||
|
if "vectors" not in config["nlp"]:
|
||||||
|
config["nlp"]["vectors"] = {"@vectors": "spacy.Vectors.v1"}
|
||||||
config_lang = config["nlp"].get("lang")
|
config_lang = config["nlp"].get("lang")
|
||||||
if config_lang is not None and config_lang != cls.lang:
|
if config_lang is not None and config_lang != cls.lang:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
@ -1796,6 +1806,7 @@ class Language:
|
||||||
filled["nlp"], validate=validate, schema=ConfigSchemaNlp
|
filled["nlp"], validate=validate, schema=ConfigSchemaNlp
|
||||||
)
|
)
|
||||||
create_tokenizer = resolved_nlp["tokenizer"]
|
create_tokenizer = resolved_nlp["tokenizer"]
|
||||||
|
create_vectors = resolved_nlp["vectors"]
|
||||||
before_creation = resolved_nlp["before_creation"]
|
before_creation = resolved_nlp["before_creation"]
|
||||||
after_creation = resolved_nlp["after_creation"]
|
after_creation = resolved_nlp["after_creation"]
|
||||||
after_pipeline_creation = resolved_nlp["after_pipeline_creation"]
|
after_pipeline_creation = resolved_nlp["after_pipeline_creation"]
|
||||||
|
@ -1816,7 +1827,12 @@ class Language:
|
||||||
# inside stuff like the spacy train function. If we loaded them here,
|
# inside stuff like the spacy train function. If we loaded them here,
|
||||||
# then we would load them twice at runtime: once when we make from config,
|
# then we would load them twice at runtime: once when we make from config,
|
||||||
# and then again when we load from disk.
|
# and then again when we load from disk.
|
||||||
nlp = lang_cls(vocab=vocab, create_tokenizer=create_tokenizer, meta=meta)
|
nlp = lang_cls(
|
||||||
|
vocab=vocab,
|
||||||
|
create_tokenizer=create_tokenizer,
|
||||||
|
create_vectors=create_vectors,
|
||||||
|
meta=meta,
|
||||||
|
)
|
||||||
if after_creation is not None:
|
if after_creation is not None:
|
||||||
nlp = after_creation(nlp)
|
nlp = after_creation(nlp)
|
||||||
if not isinstance(nlp, cls):
|
if not isinstance(nlp, cls):
|
||||||
|
|
|
@ -9,7 +9,7 @@ from thinc.util import partial
|
||||||
from ..attrs import ORTH
|
from ..attrs import ORTH
|
||||||
from ..errors import Errors, Warnings
|
from ..errors import Errors, Warnings
|
||||||
from ..tokens import Doc
|
from ..tokens import Doc
|
||||||
from ..vectors import Mode
|
from ..vectors import Mode, Vectors
|
||||||
from ..vocab import Vocab
|
from ..vocab import Vocab
|
||||||
|
|
||||||
|
|
||||||
|
@ -48,11 +48,14 @@ def forward(
|
||||||
key_attr: int = getattr(vocab.vectors, "attr", ORTH)
|
key_attr: int = getattr(vocab.vectors, "attr", ORTH)
|
||||||
keys = model.ops.flatten([cast(Ints1d, doc.to_array(key_attr)) for doc in docs])
|
keys = model.ops.flatten([cast(Ints1d, doc.to_array(key_attr)) for doc in docs])
|
||||||
W = cast(Floats2d, model.ops.as_contig(model.get_param("W")))
|
W = cast(Floats2d, model.ops.as_contig(model.get_param("W")))
|
||||||
if vocab.vectors.mode == Mode.default:
|
if isinstance(vocab.vectors, Vectors) and vocab.vectors.mode == Mode.default:
|
||||||
V = model.ops.asarray(vocab.vectors.data)
|
V = model.ops.asarray(vocab.vectors.data)
|
||||||
rows = vocab.vectors.find(keys=keys)
|
rows = vocab.vectors.find(keys=keys)
|
||||||
V = model.ops.as_contig(V[rows])
|
V = model.ops.as_contig(V[rows])
|
||||||
elif vocab.vectors.mode == Mode.floret:
|
elif isinstance(vocab.vectors, Vectors) and vocab.vectors.mode == Mode.floret:
|
||||||
|
V = vocab.vectors.get_batch(keys)
|
||||||
|
V = model.ops.as_contig(V)
|
||||||
|
elif hasattr(vocab.vectors, "get_batch"):
|
||||||
V = vocab.vectors.get_batch(keys)
|
V = vocab.vectors.get_batch(keys)
|
||||||
V = model.ops.as_contig(V)
|
V = model.ops.as_contig(V)
|
||||||
else:
|
else:
|
||||||
|
@ -61,7 +64,7 @@ def forward(
|
||||||
vectors_data = model.ops.gemm(V, W, trans2=True)
|
vectors_data = model.ops.gemm(V, W, trans2=True)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
raise RuntimeError(Errors.E896)
|
raise RuntimeError(Errors.E896)
|
||||||
if vocab.vectors.mode == Mode.default:
|
if isinstance(vocab.vectors, Vectors) and vocab.vectors.mode == Mode.default:
|
||||||
# Convert negative indices to 0-vectors
|
# Convert negative indices to 0-vectors
|
||||||
# TODO: more options for UNK tokens
|
# TODO: more options for UNK tokens
|
||||||
vectors_data[rows < 0] = 0
|
vectors_data[rows < 0] = 0
|
||||||
|
|
|
@ -397,6 +397,7 @@ class ConfigSchemaNlp(BaseModel):
|
||||||
after_creation: Optional[Callable[["Language"], "Language"]] = Field(..., title="Optional callback to modify nlp object after creation and before the pipeline is constructed")
|
after_creation: Optional[Callable[["Language"], "Language"]] = Field(..., title="Optional callback to modify nlp object after creation and before the pipeline is constructed")
|
||||||
after_pipeline_creation: Optional[Callable[["Language"], "Language"]] = Field(..., title="Optional callback to modify nlp object after the pipeline is constructed")
|
after_pipeline_creation: Optional[Callable[["Language"], "Language"]] = Field(..., title="Optional callback to modify nlp object after the pipeline is constructed")
|
||||||
batch_size: Optional[int] = Field(..., title="Default batch size")
|
batch_size: Optional[int] = Field(..., title="Default batch size")
|
||||||
|
vectors: Callable = Field(..., title="Vectors implementation")
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
|
|
|
@ -118,6 +118,7 @@ class registry(thinc.registry):
|
||||||
augmenters = catalogue.create("spacy", "augmenters", entry_points=True)
|
augmenters = catalogue.create("spacy", "augmenters", entry_points=True)
|
||||||
loggers = catalogue.create("spacy", "loggers", entry_points=True)
|
loggers = catalogue.create("spacy", "loggers", entry_points=True)
|
||||||
scorers = catalogue.create("spacy", "scorers", entry_points=True)
|
scorers = catalogue.create("spacy", "scorers", entry_points=True)
|
||||||
|
vectors = catalogue.create("spacy", "vectors", entry_points=True)
|
||||||
# These are factories registered via third-party packages and the
|
# These are factories registered via third-party packages and the
|
||||||
# spacy_factories entry point. This registry only exists so we can easily
|
# spacy_factories entry point. This registry only exists so we can easily
|
||||||
# load them via the entry points. The "true" factories are added via the
|
# load them via the entry points. The "true" factories are added via the
|
||||||
|
|
|
@ -1,3 +1,6 @@
|
||||||
|
# cython: infer_types=True, profile=True, binding=True
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
from cython.operator cimport dereference as deref
|
from cython.operator cimport dereference as deref
|
||||||
from libc.stdint cimport uint32_t, uint64_t
|
from libc.stdint cimport uint32_t, uint64_t
|
||||||
from libcpp.set cimport set as cppset
|
from libcpp.set cimport set as cppset
|
||||||
|
@ -5,7 +8,8 @@ from murmurhash.mrmr cimport hash128_x64
|
||||||
|
|
||||||
import warnings
|
import warnings
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import cast
|
from pathlib import Path
|
||||||
|
from typing import TYPE_CHECKING, Union, cast
|
||||||
|
|
||||||
import numpy
|
import numpy
|
||||||
import srsly
|
import srsly
|
||||||
|
@ -21,6 +25,9 @@ from .attrs import IDS
|
||||||
from .errors import Errors, Warnings
|
from .errors import Errors, Warnings
|
||||||
from .strings import get_string_id
|
from .strings import get_string_id
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .vocab import Vocab # noqa: F401 # no-cython-lint
|
||||||
|
|
||||||
|
|
||||||
def unpickle_vectors(bytes_data):
|
def unpickle_vectors(bytes_data):
|
||||||
return Vectors().from_bytes(bytes_data)
|
return Vectors().from_bytes(bytes_data)
|
||||||
|
@ -35,7 +42,71 @@ class Mode(str, Enum):
|
||||||
return list(cls.__members__.keys())
|
return list(cls.__members__.keys())
|
||||||
|
|
||||||
|
|
||||||
cdef class Vectors:
|
cdef class BaseVectors:
|
||||||
|
def __init__(self, *, strings=None):
|
||||||
|
# Make sure abstract BaseVectors is not instantiated.
|
||||||
|
if self.__class__ == BaseVectors:
|
||||||
|
raise TypeError(
|
||||||
|
Errors.E1046.format(cls_name=self.__class__.__name__)
|
||||||
|
)
|
||||||
|
|
||||||
|
def __getitem__(self, key):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def __contains__(self, key):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def is_full(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def get_batch(self, keys):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@property
|
||||||
|
def shape(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@property
|
||||||
|
def vectors_length(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@property
|
||||||
|
def size(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def add(self, key, *, vector=None):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def to_ops(self, ops: Ops):
|
||||||
|
pass
|
||||||
|
|
||||||
|
# add dummy methods for to_bytes, from_bytes, to_disk and from_disk to
|
||||||
|
# allow serialization
|
||||||
|
def to_bytes(self, **kwargs):
|
||||||
|
return b""
|
||||||
|
|
||||||
|
def from_bytes(self, data: bytes, **kwargs):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def to_disk(self, path: Union[str, Path], **kwargs):
|
||||||
|
return None
|
||||||
|
|
||||||
|
def from_disk(self, path: Union[str, Path], **kwargs):
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
@util.registry.vectors("spacy.Vectors.v1")
|
||||||
|
def create_mode_vectors() -> Callable[["Vocab"], BaseVectors]:
|
||||||
|
def vectors_factory(vocab: "Vocab") -> BaseVectors:
|
||||||
|
return Vectors(strings=vocab.strings)
|
||||||
|
|
||||||
|
return vectors_factory
|
||||||
|
|
||||||
|
|
||||||
|
cdef class Vectors(BaseVectors):
|
||||||
"""Store, save and load word vectors.
|
"""Store, save and load word vectors.
|
||||||
|
|
||||||
Vectors data is kept in the vectors.data attribute, which should be an
|
Vectors data is kept in the vectors.data attribute, which should be an
|
||||||
|
|
|
@ -94,6 +94,7 @@ cdef class Vocab:
|
||||||
return self._vectors
|
return self._vectors
|
||||||
|
|
||||||
def __set__(self, vectors):
|
def __set__(self, vectors):
|
||||||
|
if hasattr(vectors, "strings"):
|
||||||
for s in vectors.strings:
|
for s in vectors.strings:
|
||||||
self.strings.add(s)
|
self.strings.add(s)
|
||||||
self._vectors = vectors
|
self._vectors = vectors
|
||||||
|
@ -193,7 +194,7 @@ cdef class Vocab:
|
||||||
lex = <LexemeC*>mem.alloc(1, sizeof(LexemeC))
|
lex = <LexemeC*>mem.alloc(1, sizeof(LexemeC))
|
||||||
lex.orth = self.strings.add(string)
|
lex.orth = self.strings.add(string)
|
||||||
lex.length = len(string)
|
lex.length = len(string)
|
||||||
if self.vectors is not None:
|
if self.vectors is not None and hasattr(self.vectors, "key2row"):
|
||||||
lex.id = self.vectors.key2row.get(lex.orth, OOV_RANK)
|
lex.id = self.vectors.key2row.get(lex.orth, OOV_RANK)
|
||||||
else:
|
else:
|
||||||
lex.id = OOV_RANK
|
lex.id = OOV_RANK
|
||||||
|
@ -289,12 +290,17 @@ cdef class Vocab:
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def vectors_length(self):
|
def vectors_length(self):
|
||||||
|
if hasattr(self.vectors, "shape"):
|
||||||
return self.vectors.shape[1]
|
return self.vectors.shape[1]
|
||||||
|
else:
|
||||||
|
return -1
|
||||||
|
|
||||||
def reset_vectors(self, *, width=None, shape=None):
|
def reset_vectors(self, *, width=None, shape=None):
|
||||||
"""Drop the current vector table. Because all vectors must be the same
|
"""Drop the current vector table. Because all vectors must be the same
|
||||||
width, you have to call this to change the size of the vectors.
|
width, you have to call this to change the size of the vectors.
|
||||||
"""
|
"""
|
||||||
|
if not isinstance(self.vectors, Vectors):
|
||||||
|
raise ValueError(Errors.E849.format(method="reset_vectors", vectors_type=type(self.vectors)))
|
||||||
if width is not None and shape is not None:
|
if width is not None and shape is not None:
|
||||||
raise ValueError(Errors.E065.format(width=width, shape=shape))
|
raise ValueError(Errors.E065.format(width=width, shape=shape))
|
||||||
elif shape is not None:
|
elif shape is not None:
|
||||||
|
@ -304,6 +310,8 @@ cdef class Vocab:
|
||||||
self.vectors = Vectors(strings=self.strings, shape=(self.vectors.shape[0], width))
|
self.vectors = Vectors(strings=self.strings, shape=(self.vectors.shape[0], width))
|
||||||
|
|
||||||
def deduplicate_vectors(self):
|
def deduplicate_vectors(self):
|
||||||
|
if not isinstance(self.vectors, Vectors):
|
||||||
|
raise ValueError(Errors.E849.format(method="deduplicate_vectors", vectors_type=type(self.vectors)))
|
||||||
if self.vectors.mode != VectorsMode.default:
|
if self.vectors.mode != VectorsMode.default:
|
||||||
raise ValueError(Errors.E858.format(
|
raise ValueError(Errors.E858.format(
|
||||||
mode=self.vectors.mode,
|
mode=self.vectors.mode,
|
||||||
|
@ -357,6 +365,8 @@ cdef class Vocab:
|
||||||
|
|
||||||
DOCS: https://spacy.io/api/vocab#prune_vectors
|
DOCS: https://spacy.io/api/vocab#prune_vectors
|
||||||
"""
|
"""
|
||||||
|
if not isinstance(self.vectors, Vectors):
|
||||||
|
raise ValueError(Errors.E849.format(method="prune_vectors", vectors_type=type(self.vectors)))
|
||||||
if self.vectors.mode != VectorsMode.default:
|
if self.vectors.mode != VectorsMode.default:
|
||||||
raise ValueError(Errors.E858.format(
|
raise ValueError(Errors.E858.format(
|
||||||
mode=self.vectors.mode,
|
mode=self.vectors.mode,
|
||||||
|
|
|
@ -0,0 +1,143 @@
|
||||||
|
---
|
||||||
|
title: BaseVectors
|
||||||
|
teaser: Abstract class for word vectors
|
||||||
|
tag: class
|
||||||
|
source: spacy/vectors.pyx
|
||||||
|
version: 3.7
|
||||||
|
---
|
||||||
|
|
||||||
|
`BaseVectors` is an abstract class to support the development of custom vectors
|
||||||
|
implementations.
|
||||||
|
|
||||||
|
For use in training with [`StaticVectors`](/api/architectures#staticvectors),
|
||||||
|
`get_batch` must be implemented. For improved performance, use efficient
|
||||||
|
batching in `get_batch` and implement `to_ops` to copy the vector data to the
|
||||||
|
current device. See an example custom implementation for
|
||||||
|
[BPEmb subword embeddings](/usage/embeddings-transformers#custom-vectors).
|
||||||
|
|
||||||
|
## BaseVectors.\_\_init\_\_ {id="init",tag="method"}
|
||||||
|
|
||||||
|
Create a new vector store.
|
||||||
|
|
||||||
|
| Name | Description |
|
||||||
|
| -------------- | --------------------------------------------------------------------------------------------------------------------- |
|
||||||
|
| _keyword-only_ | |
|
||||||
|
| `strings` | The string store. A new string store is created if one is not provided. Defaults to `None`. ~~Optional[StringStore]~~ |
|
||||||
|
|
||||||
|
## BaseVectors.\_\_getitem\_\_ {id="getitem",tag="method"}
|
||||||
|
|
||||||
|
Get a vector by key. If the key is not found in the table, a `KeyError` should
|
||||||
|
be raised.
|
||||||
|
|
||||||
|
| Name | Description |
|
||||||
|
| ----------- | ---------------------------------------------------------------- |
|
||||||
|
| `key` | The key to get the vector for. ~~Union[int, str]~~ |
|
||||||
|
| **RETURNS** | The vector for the key. ~~numpy.ndarray[ndim=1, dtype=float32]~~ |
|
||||||
|
|
||||||
|
## BaseVectors.\_\_len\_\_ {id="len",tag="method"}
|
||||||
|
|
||||||
|
Return the number of vectors in the table.
|
||||||
|
|
||||||
|
| Name | Description |
|
||||||
|
| ----------- | ------------------------------------------- |
|
||||||
|
| **RETURNS** | The number of vectors in the table. ~~int~~ |
|
||||||
|
|
||||||
|
## BaseVectors.\_\_contains\_\_ {id="contains",tag="method"}
|
||||||
|
|
||||||
|
Check whether there is a vector entry for the given key.
|
||||||
|
|
||||||
|
| Name | Description |
|
||||||
|
| ----------- | -------------------------------------------- |
|
||||||
|
| `key` | The key to check. ~~int~~ |
|
||||||
|
| **RETURNS** | Whether the key has a vector entry. ~~bool~~ |
|
||||||
|
|
||||||
|
## BaseVectors.add {id="add",tag="method"}
|
||||||
|
|
||||||
|
Add a key to the table, if possible. If no keys can be added, return `-1`.
|
||||||
|
|
||||||
|
| Name | Description |
|
||||||
|
| ----------- | ----------------------------------------------------------------------------------- |
|
||||||
|
| `key` | The key to add. ~~Union[str, int]~~ |
|
||||||
|
| **RETURNS** | The row the vector was added to, or `-1` if the operation is not supported. ~~int~~ |
|
||||||
|
|
||||||
|
## BaseVectors.shape {id="shape",tag="property"}
|
||||||
|
|
||||||
|
Get `(rows, dims)` tuples of number of rows and number of dimensions in the
|
||||||
|
vector table.
|
||||||
|
|
||||||
|
| Name | Description |
|
||||||
|
| ----------- | ------------------------------------------ |
|
||||||
|
| **RETURNS** | A `(rows, dims)` pair. ~~Tuple[int, int]~~ |
|
||||||
|
|
||||||
|
## BaseVectors.size {id="size",tag="property"}
|
||||||
|
|
||||||
|
The vector size, i.e. `rows * dims`.
|
||||||
|
|
||||||
|
| Name | Description |
|
||||||
|
| ----------- | ------------------------ |
|
||||||
|
| **RETURNS** | The vector size. ~~int~~ |
|
||||||
|
|
||||||
|
## BaseVectors.is_full {id="is_full",tag="property"}
|
||||||
|
|
||||||
|
Whether the vectors table is full and no slots are available for new keys.
|
||||||
|
|
||||||
|
| Name | Description |
|
||||||
|
| ----------- | ------------------------------------------- |
|
||||||
|
| **RETURNS** | Whether the vectors table is full. ~~bool~~ |
|
||||||
|
|
||||||
|
## BaseVectors.get_batch {id="get_batch",tag="method",version="3.2"}
|
||||||
|
|
||||||
|
Get the vectors for the provided keys efficiently as a batch. Required to use
|
||||||
|
the vectors with [`StaticVectors`](/api/architectures#StaticVectors) for
|
||||||
|
training.
|
||||||
|
|
||||||
|
| Name | Description |
|
||||||
|
| ------ | --------------------------------------- |
|
||||||
|
| `keys` | The keys. ~~Iterable[Union[int, str]]~~ |
|
||||||
|
|
||||||
|
## BaseVectors.to_ops {id="to_ops",tag="method"}
|
||||||
|
|
||||||
|
Dummy method. Implement this to change the embedding matrix to use different
|
||||||
|
Thinc ops.
|
||||||
|
|
||||||
|
| Name | Description |
|
||||||
|
| ----- | -------------------------------------------------------- |
|
||||||
|
| `ops` | The Thinc ops to switch the embedding matrix to. ~~Ops~~ |
|
||||||
|
|
||||||
|
## BaseVectors.to_disk {id="to_disk",tag="method"}
|
||||||
|
|
||||||
|
Dummy method to allow serialization. Implement to save vector data with the
|
||||||
|
pipeline.
|
||||||
|
|
||||||
|
| Name | Description |
|
||||||
|
| ------ | ------------------------------------------------------------------------------------------------------------------------------------------ |
|
||||||
|
| `path` | A path to a directory, which will be created if it doesn't exist. Paths may be either strings or `Path`-like objects. ~~Union[str, Path]~~ |
|
||||||
|
|
||||||
|
## BaseVectors.from_disk {id="from_disk",tag="method"}
|
||||||
|
|
||||||
|
Dummy method to allow serialization. Implement to load vector data from a saved
|
||||||
|
pipeline.
|
||||||
|
|
||||||
|
| Name | Description |
|
||||||
|
| ----------- | ----------------------------------------------------------------------------------------------- |
|
||||||
|
| `path` | A path to a directory. Paths may be either strings or `Path`-like objects. ~~Union[str, Path]~~ |
|
||||||
|
| **RETURNS** | The modified vectors object. ~~BaseVectors~~ |
|
||||||
|
|
||||||
|
## BaseVectors.to_bytes {id="to_bytes",tag="method"}
|
||||||
|
|
||||||
|
Dummy method to allow serialization. Implement to serialize vector data to a
|
||||||
|
binary string.
|
||||||
|
|
||||||
|
| Name | Description |
|
||||||
|
| ----------- | ---------------------------------------------------- |
|
||||||
|
| **RETURNS** | The serialized form of the vectors object. ~~bytes~~ |
|
||||||
|
|
||||||
|
## BaseVectors.from_bytes {id="from_bytes",tag="method"}
|
||||||
|
|
||||||
|
Dummy method to allow serialization. Implement to load vector data from a binary
|
||||||
|
string.
|
||||||
|
|
||||||
|
| Name | Description |
|
||||||
|
| ----------- | ----------------------------------- |
|
||||||
|
| `data` | The data to load from. ~~bytes~~ |
|
||||||
|
| **RETURNS** | The vectors object. ~~BaseVectors~~ |
|
|
@ -297,10 +297,9 @@ The vector size, i.e. `rows * dims`.
|
||||||
|
|
||||||
## Vectors.is_full {id="is_full",tag="property"}
|
## Vectors.is_full {id="is_full",tag="property"}
|
||||||
|
|
||||||
Whether the vectors table is full and has no slots are available for new keys.
|
Whether the vectors table is full and no slots are available for new keys. If a
|
||||||
If a table is full, it can be resized using
|
table is full, it can be resized using [`Vectors.resize`](/api/vectors#resize).
|
||||||
[`Vectors.resize`](/api/vectors#resize). In `floret` mode, the table is always
|
In `floret` mode, the table is always full and cannot be resized.
|
||||||
full and cannot be resized.
|
|
||||||
|
|
||||||
> #### Example
|
> #### Example
|
||||||
>
|
>
|
||||||
|
@ -441,7 +440,7 @@ Load state from a binary string.
|
||||||
> #### Example
|
> #### Example
|
||||||
>
|
>
|
||||||
> ```python
|
> ```python
|
||||||
> fron spacy.vectors import Vectors
|
> from spacy.vectors import Vectors
|
||||||
> vectors_bytes = vectors.to_bytes()
|
> vectors_bytes = vectors.to_bytes()
|
||||||
> new_vectors = Vectors(StringStore())
|
> new_vectors = Vectors(StringStore())
|
||||||
> new_vectors.from_bytes(vectors_bytes)
|
> new_vectors.from_bytes(vectors_bytes)
|
||||||
|
|
|
@ -632,6 +632,165 @@ def MyCustomVectors(
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
#### Creating a custom vectors implementation {id="custom-vectors",version="3.7"}
|
||||||
|
|
||||||
|
You can specify a custom registered vectors class under `[nlp.vectors]` in order
|
||||||
|
to use static vectors in formats other than the ones supported by
|
||||||
|
[`Vectors`](/api/vectors). Extend the abstract [`BaseVectors`](/api/basevectors)
|
||||||
|
class to implement your custom vectors.
|
||||||
|
|
||||||
|
As an example, the following `BPEmbVectors` class implements support for
|
||||||
|
[BPEmb subword embeddings](https://bpemb.h-its.org/):
|
||||||
|
|
||||||
|
```python
|
||||||
|
# requires: pip install bpemb
|
||||||
|
import warnings
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Callable, Optional, cast
|
||||||
|
|
||||||
|
from bpemb import BPEmb
|
||||||
|
from thinc.api import Ops, get_current_ops
|
||||||
|
from thinc.backends import get_array_ops
|
||||||
|
from thinc.types import Floats2d
|
||||||
|
|
||||||
|
from spacy.strings import StringStore
|
||||||
|
from spacy.util import registry
|
||||||
|
from spacy.vectors import BaseVectors
|
||||||
|
from spacy.vocab import Vocab
|
||||||
|
|
||||||
|
|
||||||
|
class BPEmbVectors(BaseVectors):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
strings: Optional[StringStore] = None,
|
||||||
|
lang: Optional[str] = None,
|
||||||
|
vs: Optional[int] = None,
|
||||||
|
dim: Optional[int] = None,
|
||||||
|
cache_dir: Optional[Path] = None,
|
||||||
|
encode_extra_options: Optional[str] = None,
|
||||||
|
model_file: Optional[Path] = None,
|
||||||
|
emb_file: Optional[Path] = None,
|
||||||
|
):
|
||||||
|
kwargs = {}
|
||||||
|
if lang is not None:
|
||||||
|
kwargs["lang"] = lang
|
||||||
|
if vs is not None:
|
||||||
|
kwargs["vs"] = vs
|
||||||
|
if dim is not None:
|
||||||
|
kwargs["dim"] = dim
|
||||||
|
if cache_dir is not None:
|
||||||
|
kwargs["cache_dir"] = cache_dir
|
||||||
|
if encode_extra_options is not None:
|
||||||
|
kwargs["encode_extra_options"] = encode_extra_options
|
||||||
|
if model_file is not None:
|
||||||
|
kwargs["model_file"] = model_file
|
||||||
|
if emb_file is not None:
|
||||||
|
kwargs["emb_file"] = emb_file
|
||||||
|
self.bpemb = BPEmb(**kwargs)
|
||||||
|
self.strings = strings
|
||||||
|
self.name = repr(self.bpemb)
|
||||||
|
self.n_keys = -1
|
||||||
|
self.mode = "BPEmb"
|
||||||
|
self.to_ops(get_current_ops())
|
||||||
|
|
||||||
|
def __contains__(self, key):
|
||||||
|
return True
|
||||||
|
|
||||||
|
def is_full(self):
|
||||||
|
return True
|
||||||
|
|
||||||
|
def add(self, key, *, vector=None, row=None):
|
||||||
|
warnings.warn(
|
||||||
|
(
|
||||||
|
"Skipping BPEmbVectors.add: the bpemb vector table cannot be "
|
||||||
|
"modified. Vectors are calculated from bytepieces."
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return -1
|
||||||
|
|
||||||
|
def __getitem__(self, key):
|
||||||
|
return self.get_batch([key])[0]
|
||||||
|
|
||||||
|
def get_batch(self, keys):
|
||||||
|
keys = [self.strings.as_string(key) for key in keys]
|
||||||
|
bp_ids = self.bpemb.encode_ids(keys)
|
||||||
|
ops = get_array_ops(self.bpemb.emb.vectors)
|
||||||
|
indices = ops.asarray(ops.xp.hstack(bp_ids), dtype="int32")
|
||||||
|
lengths = ops.asarray([len(x) for x in bp_ids], dtype="int32")
|
||||||
|
vecs = ops.reduce_mean(cast(Floats2d, self.bpemb.emb.vectors[indices]), lengths)
|
||||||
|
return vecs
|
||||||
|
|
||||||
|
@property
|
||||||
|
def shape(self):
|
||||||
|
return self.bpemb.vectors.shape
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self.shape[0]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def vectors_length(self):
|
||||||
|
return self.shape[1]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def size(self):
|
||||||
|
return self.bpemb.vectors.size
|
||||||
|
|
||||||
|
def to_ops(self, ops: Ops):
|
||||||
|
self.bpemb.emb.vectors = ops.asarray(self.bpemb.emb.vectors)
|
||||||
|
|
||||||
|
|
||||||
|
@registry.vectors("BPEmbVectors.v1")
|
||||||
|
def create_bpemb_vectors(
|
||||||
|
lang: Optional[str] = "multi",
|
||||||
|
vs: Optional[int] = None,
|
||||||
|
dim: Optional[int] = None,
|
||||||
|
cache_dir: Optional[Path] = None,
|
||||||
|
encode_extra_options: Optional[str] = None,
|
||||||
|
model_file: Optional[Path] = None,
|
||||||
|
emb_file: Optional[Path] = None,
|
||||||
|
) -> Callable[[Vocab], BPEmbVectors]:
|
||||||
|
def bpemb_vectors_factory(vocab: Vocab) -> BPEmbVectors:
|
||||||
|
return BPEmbVectors(
|
||||||
|
strings=vocab.strings,
|
||||||
|
lang=lang,
|
||||||
|
vs=vs,
|
||||||
|
dim=dim,
|
||||||
|
cache_dir=cache_dir,
|
||||||
|
encode_extra_options=encode_extra_options,
|
||||||
|
model_file=model_file,
|
||||||
|
emb_file=emb_file,
|
||||||
|
)
|
||||||
|
|
||||||
|
return bpemb_vectors_factory
|
||||||
|
```
|
||||||
|
|
||||||
|
<Infobox variant="warning">
|
||||||
|
|
||||||
|
Note that the serialization methods are not implemented, so the embeddings are
|
||||||
|
loaded from your local cache or downloaded by `BPEmb` each time the pipeline is
|
||||||
|
loaded.
|
||||||
|
|
||||||
|
</Infobox>
|
||||||
|
|
||||||
|
To use this in your pipeline, specify this registered function under
|
||||||
|
`[nlp.vectors]` in your config:
|
||||||
|
|
||||||
|
```ini
|
||||||
|
[nlp.vectors]
|
||||||
|
@vectors = "BPEmbVectors.v1"
|
||||||
|
lang = "en"
|
||||||
|
```
|
||||||
|
|
||||||
|
Or specify it when creating a blank pipeline:
|
||||||
|
|
||||||
|
```python
|
||||||
|
nlp = spacy.blank("en", config={"nlp.vectors": {"@vectors": "BPEmbVectors.v1", "lang": "en"}})
|
||||||
|
```
|
||||||
|
|
||||||
|
Remember to include this code with `--code` when using
|
||||||
|
[`spacy train`](/api/cli#train) and [`spacy package`](/api/cli#package).
|
||||||
|
|
||||||
## Pretraining {id="pretraining"}
|
## Pretraining {id="pretraining"}
|
||||||
|
|
||||||
The [`spacy pretrain`](/api/cli#pretrain) command lets you initialize your
|
The [`spacy pretrain`](/api/cli#pretrain) command lets you initialize your
|
||||||
|
|
|
@ -131,6 +131,7 @@
|
||||||
"label": "Other",
|
"label": "Other",
|
||||||
"items": [
|
"items": [
|
||||||
{ "text": "Attributes", "url": "/api/attributes" },
|
{ "text": "Attributes", "url": "/api/attributes" },
|
||||||
|
{ "text": "BaseVectors", "url": "/api/basevectors" },
|
||||||
{ "text": "Corpus", "url": "/api/corpus" },
|
{ "text": "Corpus", "url": "/api/corpus" },
|
||||||
{ "text": "InMemoryLookupKB", "url": "/api/inmemorylookupkb" },
|
{ "text": "InMemoryLookupKB", "url": "/api/inmemorylookupkb" },
|
||||||
{ "text": "KnowledgeBase", "url": "/api/kb" },
|
{ "text": "KnowledgeBase", "url": "/api/kb" },
|
||||||
|
|
Loading…
Reference in New Issue