Mirror of https://github.com/explosion/spaCy.git
Support keeping word vector data on the GPU
This commit is contained in:
parent
95bca20c17
commit
e0a2aa9289
|
@ -6,6 +6,8 @@ import msgpack
|
||||||
import msgpack_numpy
|
import msgpack_numpy
|
||||||
msgpack_numpy.patch()
|
msgpack_numpy.patch()
|
||||||
cimport numpy as np
|
cimport numpy as np
|
||||||
|
from thinc.neural.util import get_array_module
|
||||||
|
from thinc.neural._classes.model import Model
|
||||||
|
|
||||||
from .typedefs cimport attr_t
|
from .typedefs cimport attr_t
|
||||||
from .strings cimport StringStore
|
from .strings cimport StringStore
|
||||||
|
@ -31,7 +33,7 @@ cdef class Vectors:
|
||||||
self.i = 0
|
self.i = 0
|
||||||
self.data = data
|
self.data = data
|
||||||
self.key2row = {}
|
self.key2row = {}
|
||||||
self.keys = np.ndarray((self.data.shape[0],), dtype='uint64')
|
self.keys = np.ndarray((self.data.shape[0],), dtype='uint64')
|
||||||
|
|
||||||
def __reduce__(self):
|
def __reduce__(self):
|
||||||
return (Vectors, (self.strings, self.data))
|
return (Vectors, (self.strings, self.data))
|
||||||
|
@ -118,9 +120,14 @@ cdef class Vectors:
|
||||||
self.data
|
self.data
|
||||||
|
|
||||||
def to_disk(self, path, **exclude):
|
def to_disk(self, path, **exclude):
|
||||||
|
xp = get_array_module(self.data)
|
||||||
|
if xp is numpy:
|
||||||
|
save_array = lambda arr, file_: xp.save(file_, arr, allow_pickle=False)
|
||||||
|
else:
|
||||||
|
save_array = lambda arr, file_: xp.save(file_, arr)
|
||||||
serializers = OrderedDict((
|
serializers = OrderedDict((
|
||||||
('vectors', lambda p: numpy.save(p.open('wb'), self.data, allow_pickle=False)),
|
('vectors', lambda p: save_array(self.data, p.open('wb'))),
|
||||||
('keys', lambda p: numpy.save(p.open('wb'), self.keys, allow_pickle=False)),
|
('keys', lambda p: xp.save(p.open('wb'), self.keys))
|
||||||
))
|
))
|
||||||
return util.to_disk(path, serializers, exclude)
|
return util.to_disk(path, serializers, exclude)
|
||||||
|
|
||||||
|
@ -133,8 +140,9 @@ cdef class Vectors:
|
||||||
self.key2row[key] = i
|
self.key2row[key] = i
|
||||||
|
|
||||||
def load_vectors(path):
|
def load_vectors(path):
|
||||||
|
xp = Model.ops.xp
|
||||||
if path.exists():
|
if path.exists():
|
||||||
self.data = numpy.load(path)
|
self.data = xp.load(path)
|
||||||
|
|
||||||
serializers = OrderedDict((
|
serializers = OrderedDict((
|
||||||
('keys', load_keys),
|
('keys', load_keys),
|
||||||
|
|
Loading…
Reference in New Issue