diff --git a/spacy/ml/staticvectors.py b/spacy/ml/staticvectors.py index 4c9e53563..ce2c7efff 100644 --- a/spacy/ml/staticvectors.py +++ b/spacy/ml/staticvectors.py @@ -37,15 +37,14 @@ def forward( if not len(docs): return _handle_empty(model.ops, model.get_dim("nO")) key_attr = model.attrs["key_attr"] - W = cast(Floats2d, model.get_param("W")) + W = cast(Floats2d, model.ops.as_contig(model.get_param("W"))) V = cast(Floats2d, docs[0].vocab.vectors.data) mask = _get_drop_mask(model.ops, W.shape[0], model.attrs.get("dropout_rate")) - rows = model.ops.flatten( [doc.vocab.vectors.find(keys=doc.to_array(key_attr)) for doc in docs] ) output = Ragged( - model.ops.gemm(V[rows], W, trans2=True), + model.ops.gemm(model.ops.as_contig(V[rows]), W, trans2=True), model.ops.asarray([len(doc) for doc in docs], dtype="i") ) if mask is not None: @@ -54,7 +53,14 @@ def forward( def backprop(d_output: Ragged) -> List[Doc]: if mask is not None: d_output.data *= mask - model.inc_grad("W", model.ops.gemm(d_output.data, V[rows], trans1=True)) + model.inc_grad( + "W", + model.ops.gemm( + d_output.data, + model.ops.as_contig(V[rows]), + trans1=True + ) + ) return [] return output, backprop