Use openblas.sgemm in parser

This commit is contained in:
Matthew Honnibal 2018-03-13 02:12:01 +01:00
parent d55620041b
commit 952c87409e
1 changed files with 3 additions and 3 deletions

View File

@ -144,8 +144,8 @@ class PrecomputableAffine(Model):
self.nF = nF self.nF = nF
def begin_update(self, X, drop=0.): def begin_update(self, X, drop=0.):
Yf = self.ops.xp.dot(X, Yf = self.ops.gemm(X,
self.W.reshape((self.nF*self.nO*self.nP, self.nI)).T) self.W.reshape((self.nF*self.nO*self.nP, self.nI)), trans2=True)
Yf = Yf.reshape((Yf.shape[0], self.nF, self.nO, self.nP)) Yf = Yf.reshape((Yf.shape[0], self.nF, self.nO, self.nP))
Yf = self._add_padding(Yf) Yf = self._add_padding(Yf)
@ -165,7 +165,7 @@ class PrecomputableAffine(Model):
# Reuse the buffer # Reuse the buffer
dWopfi = Wopfi; dWopfi.fill(0.) dWopfi = Wopfi; dWopfi.fill(0.)
self.ops.xp.dot(dY.T, Xf, out=dWopfi) self.ops.gemm(dY, Xf, out=dWopfi, trans1=True)
dWopfi = dWopfi.reshape((self.nO, self.nP, self.nF, self.nI)) dWopfi = dWopfi.reshape((self.nO, self.nP, self.nF, self.nI))
# (o, p, f, i) --> (f, o, p, i) # (o, p, f, i) --> (f, o, p, i)
self.d_W += dWopfi.transpose((2, 0, 1, 3)) self.d_W += dWopfi.transpose((2, 0, 1, 3))