mirror of https://github.com/explosion/spaCy.git
Use openblas.sgemm in parser
This commit is contained in:
parent
d55620041b
commit
952c87409e
|
@ -144,8 +144,8 @@ class PrecomputableAffine(Model):
|
||||||
self.nF = nF
|
self.nF = nF
|
||||||
|
|
||||||
def begin_update(self, X, drop=0.):
|
def begin_update(self, X, drop=0.):
|
||||||
Yf = self.ops.xp.dot(X,
|
Yf = self.ops.gemm(X,
|
||||||
self.W.reshape((self.nF*self.nO*self.nP, self.nI)).T)
|
self.W.reshape((self.nF*self.nO*self.nP, self.nI)), trans2=True)
|
||||||
Yf = Yf.reshape((Yf.shape[0], self.nF, self.nO, self.nP))
|
Yf = Yf.reshape((Yf.shape[0], self.nF, self.nO, self.nP))
|
||||||
Yf = self._add_padding(Yf)
|
Yf = self._add_padding(Yf)
|
||||||
|
|
||||||
|
@ -165,7 +165,7 @@ class PrecomputableAffine(Model):
|
||||||
|
|
||||||
# Reuse the buffer
|
# Reuse the buffer
|
||||||
dWopfi = Wopfi; dWopfi.fill(0.)
|
dWopfi = Wopfi; dWopfi.fill(0.)
|
||||||
self.ops.xp.dot(dY.T, Xf, out=dWopfi)
|
self.ops.gemm(dY, Xf, out=dWopfi, trans1=True)
|
||||||
dWopfi = dWopfi.reshape((self.nO, self.nP, self.nF, self.nI))
|
dWopfi = dWopfi.reshape((self.nO, self.nP, self.nF, self.nI))
|
||||||
# (o, p, f, i) --> (f, o, p, i)
|
# (o, p, f, i) --> (f, o, p, i)
|
||||||
self.d_W += dWopfi.transpose((2, 0, 1, 3))
|
self.d_W += dWopfi.transpose((2, 0, 1, 3))
|
||||||
|
|
Loading…
Reference in New Issue