mirror of https://github.com/explosion/spaCy.git
Add mix weights on fine_tune
This commit is contained in:
parent
42bd26f6f3
commit
5d837c3776
13
spacy/_ml.py
13
spacy/_ml.py
|
@@ -356,17 +356,24 @@ def fine_tune(embedding, combine=None):
|
||||||
lengths = model.ops.asarray([len(doc) for doc in docs], dtype='i')
|
lengths = model.ops.asarray([len(doc) for doc in docs], dtype='i')
|
||||||
|
|
||||||
vecs, bp_vecs = embedding.begin_update(docs, drop=drop)
|
vecs, bp_vecs = embedding.begin_update(docs, drop=drop)
|
||||||
|
flat_tokvecs = embedding.ops.flatten(tokvecs)
|
||||||
|
flat_vecs = embedding.ops.flatten(vecs)
|
||||||
output = embedding.ops.unflatten(
|
output = embedding.ops.unflatten(
|
||||||
embedding.ops.flatten(tokvecs)
|
(model.mix[0] * flat_vecs + model.mix[1] * flat_tokvecs),
|
||||||
+ embedding.ops.flatten(vecs),
|
|
||||||
lengths)
|
lengths)
|
||||||
|
|
||||||
def fine_tune_bwd(d_output, sgd=None):
|
def fine_tune_bwd(d_output, sgd=None):
|
||||||
bp_vecs(d_output, sgd=sgd)
|
bp_vecs(d_output, sgd=sgd)
|
||||||
|
flat_grad = model.ops.flatten(d_output)
|
||||||
|
model.d_mix[1] += flat_tokvecs.dot(flat_grad.T).sum()
|
||||||
|
model.d_mix[0] += flat_vecs.dot(flat_grad.T).sum()
|
||||||
|
sgd(model._mem.weights, model._mem.gradient, key=model.id)
|
||||||
return d_output
|
return d_output
|
||||||
return output, fine_tune_bwd
|
return output, fine_tune_bwd
|
||||||
model = wrap(fine_tune_fwd, embedding)
|
model = wrap(fine_tune_fwd, embedding)
|
||||||
|
model.mix = model._mem.add((model.id, 'mix'), (2,))
|
||||||
|
model.mix.fill(1.)
|
||||||
|
model.d_mix = model._mem.add_gradient((model.id, 'd_mix'), (model.id, 'mix'))
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue