From 5d837c37762cb06a230906be80225e0e421c6cb2 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Mon, 7 Aug 2017 06:32:59 -0500 Subject: [PATCH] Add mix weights on fine_tune --- spacy/_ml.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/spacy/_ml.py b/spacy/_ml.py index f7ab9b259..d28f48c42 100644 --- a/spacy/_ml.py +++ b/spacy/_ml.py @@ -356,17 +356,24 @@ def fine_tune(embedding, combine=None): lengths = model.ops.asarray([len(doc) for doc in docs], dtype='i') vecs, bp_vecs = embedding.begin_update(docs, drop=drop) - + flat_tokvecs = embedding.ops.flatten(tokvecs) + flat_vecs = embedding.ops.flatten(vecs) output = embedding.ops.unflatten( - embedding.ops.flatten(tokvecs) - + embedding.ops.flatten(vecs), + (model.mix[0] * flat_vecs + model.mix[1] * flat_tokvecs), lengths) def fine_tune_bwd(d_output, sgd=None): bp_vecs(d_output, sgd=sgd) + flat_grad = model.ops.flatten(d_output) + model.d_mix[1] += flat_tokvecs.dot(flat_grad.T).sum() + model.d_mix[0] += flat_vecs.dot(flat_grad.T).sum() + sgd(model._mem.weights, model._mem.gradient, key=model.id) return d_output return output, fine_tune_bwd model = wrap(fine_tune_fwd, embedding) + model.mix = model._mem.add((model.id, 'mix'), (2,)) + model.mix.fill(1.) + model.d_mix = model._mem.add_gradient((model.id, 'd_mix'), (model.id, 'mix')) return model