Support having word vectors data on GPU

This commit is contained in:
Matthew Honnibal 2017-09-16 12:45:09 -05:00
parent 95bca20c17
commit e0a2aa9289
1 changed file with 12 additions and 4 deletions

View File

@@ -6,6 +6,8 @@ import msgpack
import msgpack_numpy import msgpack_numpy
msgpack_numpy.patch() msgpack_numpy.patch()
cimport numpy as np cimport numpy as np
from thinc.neural.util import get_array_module
from thinc.neural._classes.model import Model
from .typedefs cimport attr_t from .typedefs cimport attr_t
from .strings cimport StringStore from .strings cimport StringStore
@@ -31,7 +33,7 @@ cdef class Vectors:
self.i = 0 self.i = 0
self.data = data self.data = data
self.key2row = {} self.key2row = {}
self.keys = np.ndarray((self.data.shape[0],), dtype='uint64') self.keys = np.ndarray((self.data.shape[0],), dtype='uint64')
def __reduce__(self): def __reduce__(self):
return (Vectors, (self.strings, self.data)) return (Vectors, (self.strings, self.data))
@@ -118,9 +120,14 @@ cdef class Vectors:
self.data self.data
def to_disk(self, path, **exclude): def to_disk(self, path, **exclude):
xp = get_array_module(self.data)
if xp is numpy:
save_array = lambda arr, file_: xp.save(file_, arr, allow_pickle=False)
else:
save_array = lambda arr, file_: xp.save(file_, arr)
serializers = OrderedDict(( serializers = OrderedDict((
('vectors', lambda p: numpy.save(p.open('wb'), self.data, allow_pickle=False)), ('vectors', lambda p: save_array(self.data, p.open('wb'))),
('keys', lambda p: numpy.save(p.open('wb'), self.keys, allow_pickle=False)), ('keys', lambda p: xp.save(p.open('wb'), self.keys))
)) ))
return util.to_disk(path, serializers, exclude) return util.to_disk(path, serializers, exclude)
@@ -133,8 +140,9 @@ cdef class Vectors:
self.key2row[key] = i self.key2row[key] = i
def load_vectors(path): def load_vectors(path):
xp = Model.ops.xp
if path.exists(): if path.exists():
self.data = numpy.load(path) self.data = xp.load(path)
serializers = OrderedDict(( serializers = OrderedDict((
('keys', load_keys), ('keys', load_keys),