From e026b29ea92c22de3ff11a56d6648ff404138c80 Mon Sep 17 00:00:00 2001
From: Matthew Honnibal
Date: Mon, 30 Oct 2017 17:59:43 +0100
Subject: [PATCH] Add prune_vectors method to Vocab

---
 spacy/vocab.pyx | 39 +++++++++++++++++++++++++++++++++++++++
 1 file changed, 39 insertions(+)

diff --git a/spacy/vocab.pyx b/spacy/vocab.pyx
index 160f0d5bd..ff6c5b844 100644
--- a/spacy/vocab.pyx
+++ b/spacy/vocab.pyx
@@ -5,6 +5,7 @@
 import numpy
 import dill
 from collections import OrderedDict
+from thinc.neural.util import get_array_module
 from .lexeme cimport EMPTY_LEXEME
 from .lexeme cimport Lexeme
 from .strings cimport hash_string
@@ -247,6 +248,44 @@ cdef class Vocab:
         width = self.vectors.data.shape[1]
         self.vectors = Vectors(self.strings, width=width)
 
+    def prune_vectors(self, nr_row, batch_size=1024):
+        """Reduce the current vector table to `nr_row` unique entries. Words
+        mapped to the discarded vectors will be remapped to the closest vector
+        among those remaining.
+
+        For example, suppose the original table had vectors for the words:
+        ['sat', 'cat', 'feline', 'reclined']. If we prune the vector table to
+        two rows, we would discard the vectors for 'feline' and 'reclined'.
+        These words would then be remapped to the closest remaining vector
+        -- so "feline" would have the same vector as "cat", and "reclined"
+        would have the same vector as "sat".
+
+        The similarities are judged by cosine. The original vectors may
+        be large, so the cosines are calculated in minibatches to reduce
+        memory usage.
+        """
+        xp = get_array_module(self.vectors.data)
+        # Work in batches, to avoid memory problems.
+        keep = self.vectors.data[:nr_row]
+        toss = self.vectors.data[nr_row:]
+        # Normalize the rows, so cosine similarity is just the dot product.
+        # Note we can't modify the ones we're keeping in-place...
+        keep = keep / (xp.linalg.norm(keep, axis=1, keepdims=True) + 1e-8)
+        keep = xp.ascontiguousarray(keep.T)
+        neighbours = xp.zeros((toss.shape[0],), dtype='i')
+        for i in range(0, toss.shape[0], batch_size):
+            batch = toss[i : i+batch_size]
+            batch /= xp.linalg.norm(batch, axis=1, keepdims=True) + 1e-8
+            neighbours[i : i+batch_size] = xp.dot(batch, keep).argmax(axis=1)
+        for lex in self:
+            # If we're losing the vector for this word, map it to the nearest
+            # vector we're keeping.
+            if lex.rank >= nr_row:
+                lex.rank = neighbours[lex.rank-nr_row]
+                self.vectors.add(lex.orth, row=lex.rank)
+        # Make a copy, to encourage the original table to be garbage collected.
+        self.vectors.data = xp.ascontiguousarray(self.vectors.data[:nr_row])
+
     def get_vector(self, orth):
         """Retrieve a vector for a word in the vocabulary. Words can be
         looked up by string or int ID. If no vectors data is loaded, ValueError is
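
Usage note (not part of the patch): a minimal sketch of how the new method
would be called. The model name 'en_core_web_md' and the shapes in the
comments are illustrative assumptions, not taken from the diff.

    import spacy

    nlp = spacy.load('en_core_web_md')   # assumes a model with a vector table
    print(nlp.vocab.vectors.data.shape)  # e.g. (685000, 300)
    nlp.vocab.prune_vectors(10000)       # keep only the first 10000 rows
    print(nlp.vocab.vectors.data.shape)  # (10000, 300)

Because the method keeps `self.vectors.data[:nr_row]`, it assumes the table
is ordered with the most important rows first; everything past `nr_row` is
remapped to its nearest kept neighbour rather than preserved.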
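
For reference, the batched cosine nearest-neighbour search at the heart of
the patch, as a self-contained numpy sketch using the docstring's
'sat'/'cat'/'feline'/'reclined' example. The function name and toy vectors
are illustrative, not part of the patch.

    import numpy

    def nearest_kept_rows(vectors, nr_row, batch_size=1024):
        # Rows [0, nr_row) are kept. For each discarded row, return the
        # index of the most cosine-similar kept row.
        keep = vectors[:nr_row]
        toss = vectors[nr_row:]
        # Row-wise L2 normalization turns cosine similarity into a dot product.
        keep = keep / (numpy.linalg.norm(keep, axis=1, keepdims=True) + 1e-8)
        keep = numpy.ascontiguousarray(keep.T)
        neighbours = numpy.zeros((toss.shape[0],), dtype='i')
        for i in range(0, toss.shape[0], batch_size):
            batch = toss[i : i+batch_size]
            batch = batch / (numpy.linalg.norm(batch, axis=1, keepdims=True) + 1e-8)
            neighbours[i : i+batch_size] = numpy.dot(batch, keep).argmax(axis=1)
        return neighbours

    # Prune ['sat', 'cat', 'feline', 'reclined'] down to two rows.
    vectors = numpy.array([
        [0.9, 0.1],   # 'sat'      (kept, row 0)
        [0.1, 0.9],   # 'cat'      (kept, row 1)
        [0.2, 0.8],   # 'feline'   (discarded)
        [0.8, 0.2],   # 'reclined' (discarded)
    ])
    print(nearest_kept_rows(vectors, nr_row=2))
    # -> [1 0]: 'feline' maps to 'cat', 'reclined' maps to 'sat'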