From 0fe43f40f1390092dc265c9f2b2cef58ae06cc58 Mon Sep 17 00:00:00 2001
From: Adriane Boyd
Date: Tue, 1 Aug 2023 15:46:08 +0200
Subject: [PATCH] Support registered vectors (#12492)

* Support registered vectors

* Format

* Auto-fill [nlp] on load from config and from bytes/disk

* Only auto-fill [nlp]

* Undo all changes to Language.from_disk

* Expand BaseVectors

These methods are needed in various places for training and vector
similarity.

* isort

* More linting

* Only fill [nlp.vectors]

* Update spacy/vocab.pyx

Co-authored-by: Sofie Van Landeghem

* Revert changes to test related to auto-filling [nlp]

* Add vectors registry

* Rephrase error about vocab methods for vectors

* Switch to dummy implementation for BaseVectors.to_ops

* Add initial draft of docs

* Remove example from BaseVectors docs

* Apply suggestions from code review

Co-authored-by: Sofie Van Landeghem

* Update website/docs/api/basevectors.mdx

Co-authored-by: Sofie Van Landeghem

* Fix type and lint bpemb example

* Update website/docs/api/basevectors.mdx

---------

Co-authored-by: Sofie Van Landeghem
---
 spacy/default_config.cfg                   |   3 +
 spacy/errors.py                            |   2 +
 spacy/language.py                          |  18 +-
 spacy/ml/staticvectors.py                  |  11 +-
 spacy/schemas.py                           |   1 +
 spacy/util.py                              |   1 +
 spacy/vectors.pyx                          |  75 ++++++++-
 spacy/vocab.pyx                            |  18 +-
 website/docs/api/basevectors.mdx           | 143 ++++++++++++++++
 website/docs/api/vectors.mdx               |   9 +-
 .../docs/usage/embeddings-transformers.mdx | 159 ++++++++++++++++++
 website/meta/sidebars.json                 |   1 +
 12 files changed, 425 insertions(+), 16 deletions(-)
 create mode 100644 website/docs/api/basevectors.mdx

diff --git a/spacy/default_config.cfg b/spacy/default_config.cfg
index 694fb732f..b005eef40 100644
--- a/spacy/default_config.cfg
+++ b/spacy/default_config.cfg
@@ -26,6 +26,9 @@ batch_size = 1000
 
 [nlp.tokenizer]
 @tokenizers = "spacy.Tokenizer.v1"
 
+[nlp.vectors]
+@vectors = "spacy.Vectors.v1"
+
 # The pipeline components and their models
 [components]
diff --git a/spacy/errors.py b/spacy/errors.py
index 225cb9c86..14ec669a3 100644
--- a/spacy/errors.py
+++ b/spacy/errors.py
@@ -553,6 +553,8 @@ class Errors(metaclass=ErrorsWithCodes):
             "during training, make sure to include it in 'annotating components'")
 
     # New errors added in v3.x
+    E849 = ("The vocab only supports {method} for vectors of type "
+            "spacy.vectors.Vectors, not {vectors_type}.")
     E850 = ("The PretrainVectors objective currently only supports default or "
             "floret vectors, not {mode} vectors.")
     E851 = ("The 'textcat' component labels should only have values of 0 or 1, "
diff --git a/spacy/language.py b/spacy/language.py
index b144b2c32..26152b90a 100644
--- a/spacy/language.py
+++ b/spacy/language.py
@@ -65,6 +65,7 @@ from .util import (
     registry,
     warn_if_jupyter_cupy,
 )
+from .vectors import BaseVectors
 from .vocab import Vocab, create_vocab
 
 PipeCallable = Callable[[Doc], Doc]
@@ -158,6 +159,7 @@ class Language:
         max_length: int = 10**6,
         meta: Dict[str, Any] = {},
         create_tokenizer: Optional[Callable[["Language"], Callable[[str], Doc]]] = None,
+        create_vectors: Optional[Callable[["Vocab"], BaseVectors]] = None,
         batch_size: int = 1000,
         **kwargs,
     ) -> None:
@@ -198,6 +200,10 @@ class Language:
         if vocab is True:
             vectors_name = meta.get("vectors", {}).get("name")
             vocab = create_vocab(self.lang, self.Defaults, vectors_name=vectors_name)
+            if not create_vectors:
+                vectors_cfg = {"vectors": self._config["nlp"]["vectors"]}
+                create_vectors = registry.resolve(vectors_cfg)["vectors"]
+            vocab.vectors = create_vectors(vocab)
         else:
             if (self.lang and vocab.lang) and (self.lang != vocab.lang):
                 raise ValueError(Errors.E150.format(nlp=self.lang, vocab=vocab.lang))
@@ -1765,6 +1771,10 @@ class Language:
         ).merge(config)
         if "nlp" not in config:
             raise ValueError(Errors.E985.format(config=config))
+        # fill in [nlp.vectors] if not present (as a narrower alternative to
+        # auto-filling [nlp] from the default config)
+        if "vectors" not in config["nlp"]:
+            config["nlp"]["vectors"] = {"@vectors": "spacy.Vectors.v1"}
         config_lang = config["nlp"].get("lang")
         if config_lang is not None and config_lang != cls.lang:
             raise ValueError(
@@ -1796,6 +1806,7 @@ class Language:
             filled["nlp"], validate=validate, schema=ConfigSchemaNlp
         )
         create_tokenizer = resolved_nlp["tokenizer"]
+        create_vectors = resolved_nlp["vectors"]
         before_creation = resolved_nlp["before_creation"]
         after_creation = resolved_nlp["after_creation"]
         after_pipeline_creation = resolved_nlp["after_pipeline_creation"]
@@ -1816,7 +1827,12 @@ class Language:
         # inside stuff like the spacy train function. If we loaded them here,
         # then we would load them twice at runtime: once when we make from config,
         # and then again when we load from disk.
-        nlp = lang_cls(vocab=vocab, create_tokenizer=create_tokenizer, meta=meta)
+        nlp = lang_cls(
+            vocab=vocab,
+            create_tokenizer=create_tokenizer,
+            create_vectors=create_vectors,
+            meta=meta,
+        )
         if after_creation is not None:
             nlp = after_creation(nlp)
         if not isinstance(nlp, cls):
diff --git a/spacy/ml/staticvectors.py b/spacy/ml/staticvectors.py
index b75240c5d..1a1b0a0ff 100644
--- a/spacy/ml/staticvectors.py
+++ b/spacy/ml/staticvectors.py
@@ -9,7 +9,7 @@ from thinc.util import partial
 from ..attrs import ORTH
 from ..errors import Errors, Warnings
 from ..tokens import Doc
-from ..vectors import Mode
+from ..vectors import Mode, Vectors
 from ..vocab import Vocab
 
 
@@ -48,11 +48,14 @@ def forward(
     key_attr: int = getattr(vocab.vectors, "attr", ORTH)
     keys = model.ops.flatten([cast(Ints1d, doc.to_array(key_attr)) for doc in docs])
     W = cast(Floats2d, model.ops.as_contig(model.get_param("W")))
-    if vocab.vectors.mode == Mode.default:
+    if isinstance(vocab.vectors, Vectors) and vocab.vectors.mode == Mode.default:
         V = model.ops.asarray(vocab.vectors.data)
         rows = vocab.vectors.find(keys=keys)
         V = model.ops.as_contig(V[rows])
-    elif vocab.vectors.mode == Mode.floret:
+    elif isinstance(vocab.vectors, Vectors) and vocab.vectors.mode == Mode.floret:
+        V = vocab.vectors.get_batch(keys)
+        V = model.ops.as_contig(V)
+    elif hasattr(vocab.vectors, "get_batch"):
         V = vocab.vectors.get_batch(keys)
         V = model.ops.as_contig(V)
     else:
@@ -61,7 +64,7 @@ def forward(
         vectors_data = model.ops.gemm(V, W, trans2=True)
     except ValueError:
         raise RuntimeError(Errors.E896)
-    if vocab.vectors.mode == Mode.default:
+    if isinstance(vocab.vectors, Vectors) and vocab.vectors.mode == Mode.default:
         # Convert negative indices to 0-vectors
         # TODO: more options for UNK tokens
         vectors_data[rows < 0] = 0
diff --git a/spacy/schemas.py b/spacy/schemas.py
index 22c25e99d..3404687e1 100644
--- a/spacy/schemas.py
+++ b/spacy/schemas.py
@@ -397,6 +397,7 @@ class ConfigSchemaNlp(BaseModel):
     after_creation: Optional[Callable[["Language"], "Language"]] = Field(..., title="Optional callback to modify nlp object after creation and before the pipeline is constructed")
     after_pipeline_creation: Optional[Callable[["Language"], "Language"]] = Field(..., title="Optional callback to modify nlp object after the pipeline is constructed")
     batch_size: Optional[int] = Field(..., title="Default batch size")
+    vectors: Callable = Field(..., title="Vectors implementation")
     # fmt: on
 
     class Config:
diff --git a/spacy/util.py b/spacy/util.py
index a2a033cbc..1689ac827 100644
--- a/spacy/util.py
+++ b/spacy/util.py
@@ -118,6 +118,7 @@ class registry(thinc.registry):
     augmenters = catalogue.create("spacy", "augmenters", entry_points=True)
     loggers = catalogue.create("spacy", "loggers", entry_points=True)
     scorers = catalogue.create("spacy", "scorers", entry_points=True)
+    vectors = catalogue.create("spacy", "vectors", entry_points=True)
     # These are factories registered via third-party packages and the
     # spacy_factories entry point. This registry only exists so we can easily
     # load them via the entry points. The "true" factories are added via the
diff --git a/spacy/vectors.pyx b/spacy/vectors.pyx
index a88f380f9..2817bcad4 100644
--- a/spacy/vectors.pyx
+++ b/spacy/vectors.pyx
@@ -1,3 +1,6 @@
+# cython: infer_types=True, profile=True, binding=True
+from typing import Callable
+
 from cython.operator cimport dereference as deref
 from libc.stdint cimport uint32_t, uint64_t
 from libcpp.set cimport set as cppset
@@ -5,7 +8,8 @@ from murmurhash.mrmr cimport hash128_x64
 
 import warnings
 from enum import Enum
-from typing import cast
+from pathlib import Path
+from typing import TYPE_CHECKING, Union, cast
 
 import numpy
 import srsly
@@ -21,6 +25,9 @@ from .attrs import IDS
 from .errors import Errors, Warnings
 from .strings import get_string_id
 
+if TYPE_CHECKING:
+    from .vocab import Vocab  # noqa: F401 # no-cython-lint
+
 
 def unpickle_vectors(bytes_data):
     return Vectors().from_bytes(bytes_data)
@@ -35,7 +42,71 @@ class Mode(str, Enum):
         return list(cls.__members__.keys())
 
 
-cdef class Vectors:
+cdef class BaseVectors:
+    def __init__(self, *, strings=None):
+        # Make sure abstract BaseVectors is not instantiated.
+        if self.__class__ == BaseVectors:
+            raise TypeError(
+                Errors.E1046.format(cls_name=self.__class__.__name__)
+            )
+
+    def __getitem__(self, key):
+        raise NotImplementedError
+
+    def __contains__(self, key):
+        raise NotImplementedError
+
+    def is_full(self):
+        raise NotImplementedError
+
+    def get_batch(self, keys):
+        raise NotImplementedError
+
+    @property
+    def shape(self):
+        raise NotImplementedError
+
+    def __len__(self):
+        raise NotImplementedError
+
+    @property
+    def vectors_length(self):
+        raise NotImplementedError
+
+    @property
+    def size(self):
+        raise NotImplementedError
+
+    def add(self, key, *, vector=None):
+        raise NotImplementedError
+
+    def to_ops(self, ops: Ops):
+        pass
+
+    # add dummy methods for to_bytes, from_bytes, to_disk and from_disk to
+    # allow serialization
+    def to_bytes(self, **kwargs):
+        return b""
+
+    def from_bytes(self, data: bytes, **kwargs):
+        return self
+
+    def to_disk(self, path: Union[str, Path], **kwargs):
+        return None
+
+    def from_disk(self, path: Union[str, Path], **kwargs):
+        return self
+
+
+@util.registry.vectors("spacy.Vectors.v1")
+def create_mode_vectors() -> Callable[["Vocab"], BaseVectors]:
+    def vectors_factory(vocab: "Vocab") -> BaseVectors:
+        return Vectors(strings=vocab.strings)
+
+    return vectors_factory
+
+
+cdef class Vectors(BaseVectors):
     """Store, save and load word vectors.
 
     Vectors data is kept in the vectors.data attribute, which should be an
diff --git a/spacy/vocab.pyx b/spacy/vocab.pyx
index d1edc8533..48e8fcb90 100644
--- a/spacy/vocab.pyx
+++ b/spacy/vocab.pyx
@@ -94,8 +94,9 @@ cdef class Vocab:
             return self._vectors
 
         def __set__(self, vectors):
-            for s in vectors.strings:
-                self.strings.add(s)
+            if hasattr(vectors, "strings"):
+                for s in vectors.strings:
+                    self.strings.add(s)
             self._vectors = vectors
             self._vectors.strings = self.strings
 
@@ -193,7 +194,7 @@ cdef class Vocab:
         lex = <LexemeC*>mem.alloc(1, sizeof(LexemeC))
         lex.orth = self.strings.add(string)
         lex.length = len(string)
-        if self.vectors is not None:
+        if self.vectors is not None and hasattr(self.vectors, "key2row"):
             lex.id = self.vectors.key2row.get(lex.orth, OOV_RANK)
         else:
             lex.id = OOV_RANK
@@ -289,12 +290,17 @@ cdef class Vocab:
 
     @property
     def vectors_length(self):
-        return self.vectors.shape[1]
+        if hasattr(self.vectors, "shape"):
+            return self.vectors.shape[1]
+        else:
+            return -1
 
     def reset_vectors(self, *, width=None, shape=None):
         """Drop the current vector table. Because all vectors must be the same
        width, you have to call this to change the size of the vectors.
        """
+        if not isinstance(self.vectors, Vectors):
+            raise ValueError(Errors.E849.format(method="reset_vectors", vectors_type=type(self.vectors)))
         if width is not None and shape is not None:
             raise ValueError(Errors.E065.format(width=width, shape=shape))
         elif shape is not None:
@@ -304,6 +310,8 @@ cdef class Vocab:
             self.vectors = Vectors(strings=self.strings, shape=(self.vectors.shape[0], width))
 
     def deduplicate_vectors(self):
+        if not isinstance(self.vectors, Vectors):
+            raise ValueError(Errors.E849.format(method="deduplicate_vectors", vectors_type=type(self.vectors)))
         if self.vectors.mode != VectorsMode.default:
             raise ValueError(Errors.E858.format(
                 mode=self.vectors.mode,
@@ -357,6 +365,8 @@ cdef class Vocab:
 
         DOCS: https://spacy.io/api/vocab#prune_vectors
         """
+        if not isinstance(self.vectors, Vectors):
+            raise ValueError(Errors.E849.format(method="prune_vectors", vectors_type=type(self.vectors)))
         if self.vectors.mode != VectorsMode.default:
             raise ValueError(Errors.E858.format(
                 mode=self.vectors.mode,
diff --git a/website/docs/api/basevectors.mdx b/website/docs/api/basevectors.mdx
new file mode 100644
index 000000000..993b9a33e
--- /dev/null
+++ b/website/docs/api/basevectors.mdx
@@ -0,0 +1,143 @@
+---
+title: BaseVectors
+teaser: Abstract class for word vectors
+tag: class
+source: spacy/vectors.pyx
+version: 3.7
+---
+
+`BaseVectors` is an abstract class to support the development of custom vectors
+implementations.
+
+For use in training with [`StaticVectors`](/api/architectures#staticvectors),
+`get_batch` must be implemented. For improved performance, use efficient
+batching in `get_batch` and implement `to_ops` to copy the vector data to the
+current device. See an example custom implementation for
+[BPEmb subword embeddings](/usage/embeddings-transformers#custom-vectors).
+
+## BaseVectors.\_\_init\_\_ {id="init",tag="method"}
+
+Create a new vector store.
+
+| Name           | Description                                                                                                           |
+| -------------- | --------------------------------------------------------------------------------------------------------------------- |
+| _keyword-only_ |                                                                                                                       |
+| `strings`      | The string store. A new string store is created if one is not provided. Defaults to `None`. ~~Optional[StringStore]~~ |
+
+## BaseVectors.\_\_getitem\_\_ {id="getitem",tag="method"}
+
+Get a vector by key. If the key is not found in the table, a `KeyError` should
+be raised.
+
+| Name        | Description                                                       |
+| ----------- | ----------------------------------------------------------------- |
+| `key`       | The key to get the vector for. ~~Union[int, str]~~                |
+| **RETURNS** | The vector for the key. ~~numpy.ndarray[ndim=1, dtype=float32]~~  |
+
+## BaseVectors.\_\_len\_\_ {id="len",tag="method"}
+
+Return the number of vectors in the table.
+
+| Name        | Description                                 |
+| ----------- | ------------------------------------------- |
+| **RETURNS** | The number of vectors in the table. ~~int~~ |
+
+## BaseVectors.\_\_contains\_\_ {id="contains",tag="method"}
+
+Check whether there is a vector entry for the given key.
+
+| Name        | Description                                  |
+| ----------- | -------------------------------------------- |
+| `key`       | The key to check. ~~int~~                    |
+| **RETURNS** | Whether the key has a vector entry. ~~bool~~ |
+
+## BaseVectors.add {id="add",tag="method"}
+
+Add a key to the table, if possible. If no keys can be added, return `-1`.
+
+| Name        | Description                                                                          |
+| ----------- | ------------------------------------------------------------------------------------ |
+| `key`       | The key to add. ~~Union[str, int]~~                                                  |
+| **RETURNS** | The row the vector was added to, or `-1` if the operation is not supported. ~~int~~  |
+
+## BaseVectors.shape {id="shape",tag="property"}
+
+Get `(rows, dims)` tuples of number of rows and number of dimensions in the
+vector table.
+
+| Name        | Description                                |
+| ----------- | ------------------------------------------ |
+| **RETURNS** | A `(rows, dims)` pair. ~~Tuple[int, int]~~ |
+
+## BaseVectors.size {id="size",tag="property"}
+
+The vector size, i.e. `rows * dims`.
+
+| Name        | Description              |
+| ----------- | ------------------------ |
+| **RETURNS** | The vector size. ~~int~~ |
+
+## BaseVectors.is_full {id="is_full",tag="property"}
+
+Whether the vectors table is full and no slots are available for new keys.
+
+| Name        | Description                                 |
+| ----------- | ------------------------------------------- |
+| **RETURNS** | Whether the vectors table is full. ~~bool~~ |
+
+## BaseVectors.get_batch {id="get_batch",tag="method",version="3.2"}
+
+Get the vectors for the provided keys efficiently as a batch. Required to use
+the vectors with [`StaticVectors`](/api/architectures#StaticVectors) for
+training.
+
+| Name   | Description                             |
+| ------ | --------------------------------------- |
+| `keys` | The keys. ~~Iterable[Union[int, str]]~~ |
+
+## BaseVectors.to_ops {id="to_ops",tag="method"}
+
+Dummy method. Implement this to change the embedding matrix to use different
+Thinc ops.
+
+| Name  | Description                                               |
+| ----- | --------------------------------------------------------- |
+| `ops` | The Thinc ops to switch the embedding matrix to. ~~Ops~~  |
+
+## BaseVectors.to_disk {id="to_disk",tag="method"}
+
+Dummy method to allow serialization. Implement to save vector data with the
+pipeline.
+
+| Name   | Description                                                                                                                                |
+| ------ | ------------------------------------------------------------------------------------------------------------------------------------------ |
+| `path` | A path to a directory, which will be created if it doesn't exist. Paths may be either strings or `Path`-like objects. ~~Union[str, Path]~~ |
+
+## BaseVectors.from_disk {id="from_disk",tag="method"}
+
+Dummy method to allow serialization. Implement to load vector data from a saved
+pipeline.
+
+| Name        | Description                                                                                      |
+| ----------- | ------------------------------------------------------------------------------------------------ |
+| `path`      | A path to a directory. Paths may be either strings or `Path`-like objects. ~~Union[str, Path]~~  |
+| **RETURNS** | The modified vectors object. ~~BaseVectors~~                                                     |
+
+## BaseVectors.to_bytes {id="to_bytes",tag="method"}
+
+Dummy method to allow serialization. Implement to serialize vector data to a
+binary string.
+
+| Name        | Description                                           |
+| ----------- | ----------------------------------------------------- |
+| **RETURNS** | The serialized form of the vectors object. ~~bytes~~  |
+
+## BaseVectors.from_bytes {id="from_bytes",tag="method"}
+
+Dummy method to allow serialization. Implement to load vector data from a binary
+string.
+
+| Name        | Description                         |
+| ----------- | ----------------------------------- |
+| `data`      | The data to load from. ~~bytes~~    |
+| **RETURNS** | The vectors object. ~~BaseVectors~~ |
diff --git a/website/docs/api/vectors.mdx b/website/docs/api/vectors.mdx
index fa4cd0c7a..0e92eb12b 100644
--- a/website/docs/api/vectors.mdx
+++ b/website/docs/api/vectors.mdx
@@ -297,10 +297,9 @@ The vector size, i.e. `rows * dims`.
 
 ## Vectors.is_full {id="is_full",tag="property"}
 
-Whether the vectors table is full and has no slots are available for new keys.
-If a table is full, it can be resized using
-[`Vectors.resize`](/api/vectors#resize). In `floret` mode, the table is always
-full and cannot be resized.
+Whether the vectors table is full and no slots are available for new keys. If a
+table is full, it can be resized using [`Vectors.resize`](/api/vectors#resize).
+In `floret` mode, the table is always full and cannot be resized.
 
 > #### Example
 >
@@ -441,7 +440,7 @@ Load state from a binary string.
 > #### Example
 >
 > ```python
-> fron spacy.vectors import Vectors
+> from spacy.vectors import Vectors
 > vectors_bytes = vectors.to_bytes()
 > new_vectors = Vectors(StringStore())
 > new_vectors.from_bytes(vectors_bytes)
diff --git a/website/docs/usage/embeddings-transformers.mdx b/website/docs/usage/embeddings-transformers.mdx
index 5f1e5b817..2bd2856b6 100644
--- a/website/docs/usage/embeddings-transformers.mdx
+++ b/website/docs/usage/embeddings-transformers.mdx
@@ -632,6 +632,165 @@ def MyCustomVectors(
     )
 ```
 
+#### Creating a custom vectors implementation {id="custom-vectors",version="3.7"}
+
+You can specify a custom registered vectors class under `[nlp.vectors]` in order
+to use static vectors in formats other than the ones supported by
+[`Vectors`](/api/vectors). Extend the abstract [`BaseVectors`](/api/basevectors)
+class to implement your custom vectors.
+
+As an example, the following `BPEmbVectors` class implements support for
+[BPEmb subword embeddings](https://bpemb.h-its.org/):
+
+```python
+# requires: pip install bpemb
+import warnings
+from pathlib import Path
+from typing import Callable, Optional, cast
+
+from bpemb import BPEmb
+from thinc.api import Ops, get_current_ops
+from thinc.backends import get_array_ops
+from thinc.types import Floats2d
+
+from spacy.strings import StringStore
+from spacy.util import registry
+from spacy.vectors import BaseVectors
+from spacy.vocab import Vocab
+
+
+class BPEmbVectors(BaseVectors):
+    def __init__(
+        self,
+        *,
+        strings: Optional[StringStore] = None,
+        lang: Optional[str] = None,
+        vs: Optional[int] = None,
+        dim: Optional[int] = None,
+        cache_dir: Optional[Path] = None,
+        encode_extra_options: Optional[str] = None,
+        model_file: Optional[Path] = None,
+        emb_file: Optional[Path] = None,
+    ):
+        kwargs = {}
+        if lang is not None:
+            kwargs["lang"] = lang
+        if vs is not None:
+            kwargs["vs"] = vs
+        if dim is not None:
+            kwargs["dim"] = dim
+        if cache_dir is not None:
+            kwargs["cache_dir"] = cache_dir
+        if encode_extra_options is not None:
+            kwargs["encode_extra_options"] = encode_extra_options
+        if model_file is not None:
+            kwargs["model_file"] = model_file
+        if emb_file is not None:
+            kwargs["emb_file"] = emb_file
+        self.bpemb = BPEmb(**kwargs)
+        self.strings = strings
+        self.name = repr(self.bpemb)
+        self.n_keys = -1
+        self.mode = "BPEmb"
+        self.to_ops(get_current_ops())
+
+    def __contains__(self, key):
+        return True
+
+    def is_full(self):
+        return True
+
+    def add(self, key, *, vector=None, row=None):
+        warnings.warn(
+            (
+                "Skipping BPEmbVectors.add: the bpemb vector table cannot be "
+                "modified. Vectors are calculated from bytepieces."
+            )
+        )
+        return -1
+
+    def __getitem__(self, key):
+        return self.get_batch([key])[0]
+
+    def get_batch(self, keys):
+        keys = [self.strings.as_string(key) for key in keys]
+        bp_ids = self.bpemb.encode_ids(keys)
+        ops = get_array_ops(self.bpemb.emb.vectors)
+        indices = ops.asarray(ops.xp.hstack(bp_ids), dtype="int32")
+        lengths = ops.asarray([len(x) for x in bp_ids], dtype="int32")
+        vecs = ops.reduce_mean(cast(Floats2d, self.bpemb.emb.vectors[indices]), lengths)
+        return vecs
+
+    @property
+    def shape(self):
+        return self.bpemb.vectors.shape
+
+    def __len__(self):
+        return self.shape[0]
+
+    @property
+    def vectors_length(self):
+        return self.shape[1]
+
+    @property
+    def size(self):
+        return self.bpemb.vectors.size
+
+    def to_ops(self, ops: Ops):
+        self.bpemb.emb.vectors = ops.asarray(self.bpemb.emb.vectors)
+
+
+@registry.vectors("BPEmbVectors.v1")
+def create_bpemb_vectors(
+    lang: Optional[str] = "multi",
+    vs: Optional[int] = None,
+    dim: Optional[int] = None,
+    cache_dir: Optional[Path] = None,
+    encode_extra_options: Optional[str] = None,
+    model_file: Optional[Path] = None,
+    emb_file: Optional[Path] = None,
+) -> Callable[[Vocab], BPEmbVectors]:
+    def bpemb_vectors_factory(vocab: Vocab) -> BPEmbVectors:
+        return BPEmbVectors(
+            strings=vocab.strings,
+            lang=lang,
+            vs=vs,
+            dim=dim,
+            cache_dir=cache_dir,
+            encode_extra_options=encode_extra_options,
+            model_file=model_file,
+            emb_file=emb_file,
+        )
+
+    return bpemb_vectors_factory
+```
+
+<Infobox variant="warning">
+
+Note that the serialization methods are not implemented, so the embeddings are
+loaded from your local cache or downloaded by `BPEmb` each time the pipeline is
+loaded.
+
+</Infobox>
+
+To use this in your pipeline, specify this registered function under
+`[nlp.vectors]` in your config:
+
+```ini
+[nlp.vectors]
+@vectors = "BPEmbVectors.v1"
+lang = "en"
+```
+
+Or specify it when creating a blank pipeline:
+
+```python
+nlp = spacy.blank("en", config={"nlp.vectors": {"@vectors": "BPEmbVectors.v1", "lang": "en"}})
+```
+
+Remember to include this code with `--code` when using
+[`spacy train`](/api/cli#train) and [`spacy package`](/api/cli#package).
+
 ## Pretraining {id="pretraining"}
 
 The [`spacy pretrain`](/api/cli#pretrain) command lets you initialize your
diff --git a/website/meta/sidebars.json b/website/meta/sidebars.json
index 04102095f..d2f73d83a 100644
--- a/website/meta/sidebars.json
+++ b/website/meta/sidebars.json
@@ -131,6 +131,7 @@
         "label": "Other",
         "items": [
           { "text": "Attributes", "url": "/api/attributes" },
+          { "text": "BaseVectors", "url": "/api/basevectors" },
           { "text": "Corpus", "url": "/api/corpus" },
           { "text": "InMemoryLookupKB", "url": "/api/inmemorylookupkb" },
           { "text": "KnowledgeBase", "url": "/api/kb" },
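
Beyond the BPEmb example in the docs above, a quick way to smoke-test the new `[nlp.vectors]` registry is a much smaller `BaseVectors` subclass. The sketch below is not part of this patch: `HashVectors` and `"HashVectors.v1"` are illustrative names, and the class simply backs the table with a fixed random array, hashing every key onto a row, so it needs no downloads or external files:

```python
# Minimal sketch only (not part of this patch): a fixed, hash-addressed
# vector table registered under the new `vectors` registry. The names
# HashVectors / "HashVectors.v1" are illustrative.
from typing import Callable, Optional

import numpy

from spacy.strings import StringStore, get_string_id
from spacy.util import registry
from spacy.vectors import BaseVectors
from spacy.vocab import Vocab


class HashVectors(BaseVectors):
    def __init__(
        self,
        *,
        strings: Optional[StringStore] = None,
        n_rows: int = 10000,
        dim: int = 64,
    ):
        self.strings = strings
        # Fixed random table; every key hashes onto some row, so no key is OOV.
        rng = numpy.random.default_rng(0)
        self.data = rng.standard_normal((n_rows, dim)).astype("float32")

    def __contains__(self, key):
        return True

    def is_full(self):
        return True

    def add(self, key, *, vector=None):
        # The table is fixed in size and content; signal that nothing was added.
        return -1

    def __getitem__(self, key):
        return self.get_batch([key])[0]

    def get_batch(self, keys):
        # Keys arrive as strings or already-hashed IDs; map each onto a row.
        rows = [get_string_id(key) % self.data.shape[0] for key in keys]
        return self.data[numpy.asarray(rows, dtype="int64")]

    @property
    def shape(self):
        return self.data.shape

    def __len__(self):
        return self.data.shape[0]

    @property
    def vectors_length(self):
        return self.data.shape[1]

    @property
    def size(self):
        return self.data.size


@registry.vectors("HashVectors.v1")
def create_hash_vectors(
    n_rows: int = 10000, dim: int = 64
) -> Callable[[Vocab], HashVectors]:
    def hash_vectors_factory(vocab: Vocab) -> HashVectors:
        return HashVectors(strings=vocab.strings, n_rows=n_rows, dim=dim)

    return hash_vectors_factory
```

It plugs into a pipeline through the config exactly as `BPEmbVectors.v1` does in the docs above:

```ini
[nlp.vectors]
@vectors = "HashVectors.v1"
n_rows = 20000
dim = 96
```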