Fix fine-tuning

This commit is contained in:
Matthew Honnibal 2017-08-20 18:17:35 +02:00
parent 3fe0d76e6d
commit 8a59718fd6
1 changed files with 12 additions and 11 deletions

View File

@ -359,8 +359,6 @@ def get_token_vectors(tokens_attrs_vectors, drop=0.):
def backward(d_output, sgd=None): def backward(d_output, sgd=None):
return (tokens, d_output) return (tokens, d_output)
return vectors, backward return vectors, backward
def fine_tune(embedding, combine=None): def fine_tune(embedding, combine=None):
if combine is not None: if combine is not None:
raise NotImplementedError( raise NotImplementedError(
@ -372,22 +370,25 @@ def fine_tune(embedding, combine=None):
vecs, bp_vecs = embedding.begin_update(docs, drop=drop) vecs, bp_vecs = embedding.begin_update(docs, drop=drop)
flat_tokvecs = embedding.ops.flatten(tokvecs) flat_tokvecs = embedding.ops.flatten(tokvecs)
flat_vecs = embedding.ops.flatten(vecs) flat_vecs = embedding.ops.flatten(vecs)
alpha = model.mix
minus = 1-model.mix
output = embedding.ops.unflatten( output = embedding.ops.unflatten(
(model.mix[0] * flat_vecs + model.mix[1] * flat_tokvecs), (alpha * flat_tokvecs + minus * flat_vecs), lengths)
lengths)
def fine_tune_bwd(d_output, sgd=None): def fine_tune_bwd(d_output, sgd=None):
bp_vecs(d_output, sgd=sgd)
flat_grad = model.ops.flatten(d_output) flat_grad = model.ops.flatten(d_output)
model.d_mix[1] += flat_tokvecs.dot(flat_grad.T).sum() model.d_mix += flat_tokvecs.dot(flat_grad.T).sum()
model.d_mix[0] += flat_vecs.dot(flat_grad.T).sum() model.d_mix += 1-flat_vecs.dot(flat_grad.T).sum()
if sgd is not None:
sgd(model._mem.weights, model._mem.gradient, key=model.id) bp_vecs([d_o * minus for d_o in d_output], sgd=sgd)
d_output = [d_o * alpha for d_o in d_output]
sgd(model._mem.weights, model._mem.gradient, key=model.id)
model.mix = model.ops.xp.minimum(model.mix, 1.0)
return d_output return d_output
return output, fine_tune_bwd return output, fine_tune_bwd
model = wrap(fine_tune_fwd, embedding) model = wrap(fine_tune_fwd, embedding)
model.mix = model._mem.add((model.id, 'mix'), (2,)) model.mix = model._mem.add((model.id, 'mix'), (1,))
model.mix.fill(1.) model.mix.fill(0.0)
model.d_mix = model._mem.add_gradient((model.id, 'd_mix'), (model.id, 'mix')) model.d_mix = model._mem.add_gradient((model.id, 'd_mix'), (model.id, 'mix'))
return model return model