mirror of https://github.com/explosion/spaCy.git
Fix Embed and HistoryFeatures
This commit is contained in:
parent
246612cb53
commit
92066b04d6
|
@ -231,6 +231,8 @@ class Embed(Model):
|
||||||
|
|
||||||
def __init__(self, nO, nV=None, **kwargs):
|
def __init__(self, nO, nV=None, **kwargs):
|
||||||
Model.__init__(self, **kwargs)
|
Model.__init__(self, **kwargs)
|
||||||
|
if 'name' in kwargs:
|
||||||
|
self.name = kwargs['name']
|
||||||
self.column = kwargs.get('column', 0)
|
self.column = kwargs.get('column', 0)
|
||||||
self.nO = nO
|
self.nO = nO
|
||||||
self.nV = nV
|
self.nV = nV
|
||||||
|
@ -238,12 +240,12 @@ class Embed(Model):
|
||||||
def predict(self, ids):
|
def predict(self, ids):
|
||||||
if ids.ndim == 2:
|
if ids.ndim == 2:
|
||||||
ids = ids[:, self.column]
|
ids = ids[:, self.column]
|
||||||
return self._embed(ids)
|
return self.ops.xp.ascontiguousarray(self.vectors[ids])
|
||||||
|
|
||||||
def begin_update(self, ids, drop=0.):
|
def begin_update(self, ids, drop=0.):
|
||||||
if ids.ndim == 2:
|
if ids.ndim == 2:
|
||||||
ids = ids[:, self.column]
|
ids = ids[:, self.column]
|
||||||
vectors = self.vectors[ids]
|
vectors = self.ops.xp.ascontiguousarray(self.vectors[ids])
|
||||||
def backprop_embed(d_vectors, sgd=None):
|
def backprop_embed(d_vectors, sgd=None):
|
||||||
n_vectors = d_vectors.shape[0]
|
n_vectors = d_vectors.shape[0]
|
||||||
self.ops.scatter_add(self.d_vectors, ids, d_vectors)
|
self.ops.scatter_add(self.d_vectors, ids, d_vectors)
|
||||||
|
@ -255,7 +257,8 @@ class Embed(Model):
|
||||||
|
|
||||||
def HistoryFeatures(nr_class, hist_size=8, nr_dim=8):
|
def HistoryFeatures(nr_class, hist_size=8, nr_dim=8):
|
||||||
'''Wrap a model, adding features representing action history.'''
|
'''Wrap a model, adding features representing action history.'''
|
||||||
embed_tables = [Embed(nr_dim, nr_class, column=i) for i in range(hist_size)]
|
embed_tables = [Embed(nr_dim, nr_class, column=i, name='embed%d')
|
||||||
|
for i in range(hist_size)]
|
||||||
embed = concatenate(*embed_tables)
|
embed = concatenate(*embed_tables)
|
||||||
ops = embed.ops
|
ops = embed.ops
|
||||||
def add_history_fwd(vectors_hists, drop=0.):
|
def add_history_fwd(vectors_hists, drop=0.):
|
||||||
|
|
Loading…
Reference in New Issue