diff --git a/examples/training/rehearsal.py b/examples/training/rehearsal.py
new file mode 100644
index 000000000..8ec410d0b
--- /dev/null
+++ b/examples/training/rehearsal.py
@@ -0,0 +1,77 @@
+"""Prevent catastrophic forgetting with rehearsal updates."""
+import plac
+import random
+import srsly
+import spacy
+from spacy.gold import GoldParse
+from spacy.util import minibatch
+
+
+LABEL = "ANIMAL"
+TRAIN_DATA = [
+    (
+        "Horses are too tall and they pretend to care about your feelings",
+        {"entities": [(0, 6, "ANIMAL")]},
+    ),
+    ("Do they bite?", {"entities": []}),
+    (
+        "horses are too tall and they pretend to care about your feelings",
+        {"entities": [(0, 6, "ANIMAL")]},
+    ),
+    ("horses pretend to care about your feelings", {"entities": [(0, 6, "ANIMAL")]}),
+    (
+        "they pretend to care about your feelings, those horses",
+        {"entities": [(48, 54, "ANIMAL")]},
+    ),
+    ("horses?", {"entities": [(0, 6, "ANIMAL")]}),
+]
+
+
+def read_raw_data(nlp, jsonl_loc):
+    for json_obj in srsly.read_jsonl(jsonl_loc):
+        if json_obj["text"].strip():
+            doc = nlp.make_doc(json_obj["text"])
+            yield doc
+
+
+def read_gold_data(nlp, gold_loc):
+    docs = []
+    golds = []
+    for json_obj in srsly.read_jsonl(gold_loc):
+        doc = nlp.make_doc(json_obj["text"])
+        ents = [(ent["start"], ent["end"], ent["label"]) for ent in json_obj["spans"]]
+        gold = GoldParse(doc, entities=ents)
+        docs.append(doc)
+        golds.append(gold)
+    return list(zip(docs, golds))
+
+
+def main(model_name, unlabelled_loc):
+    n_iter = 10
+    dropout = 0.2
+    batch_size = 4
+    nlp = spacy.load(model_name)
+    nlp.get_pipe("ner").add_label(LABEL)
+    raw_docs = list(read_raw_data(nlp, unlabelled_loc))
+    optimizer = nlp.resume_training()
+
+    # get names of other pipes to disable them during training
+    other_pipes = [pipe for pipe in nlp.pipe_names if pipe != "ner"]
+    with nlp.disable_pipes(*other_pipes):
+        for itn in range(n_iter):
+            random.shuffle(TRAIN_DATA)
+            random.shuffle(raw_docs)
+            losses = {}
+            r_losses = {}
+            # batch up the examples using spaCy's minibatch
+            raw_batches = minibatch(raw_docs, size=batch_size)
+            for doc, gold in TRAIN_DATA:
+                nlp.update([doc], [gold], sgd=optimizer, drop=dropout, losses=losses)
+                raw_batch = list(next(raw_batches))
+                nlp.rehearse(raw_batch, sgd=optimizer, losses=r_losses)
+            print("Losses", losses)
+            print("R. Losses", r_losses)
+
+
+if __name__ == "__main__":
+    plac.call(main)
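The example above reads its unlabelled data with read_raw_data(), which expects newline-delimited JSON with a "text" key per line. A minimal sketch of producing such a file, using srsly (already imported above) with placeholder texts and a placeholder path:

    import srsly

    raw_texts = [
        {"text": "The horse galloped across the field."},
        {"text": "Do you like horses?"},
    ]
    srsly.write_jsonl("/tmp/raw_text.jsonl", raw_texts)
    # then: python rehearsal.py <model_name> /tmp/raw_text.jsonl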
Losses", r_losses) + + +if __name__ == "__main__": + plac.call(main) diff --git a/spacy/_ml.py b/spacy/_ml.py index 3a25aab1c..511babf3c 100644 --- a/spacy/_ml.py +++ b/spacy/_ml.py @@ -586,16 +586,8 @@ def build_simple_cnn_text_classifier(tok2vec, nr_class, exclusive_classes=True, if exclusive_classes: output_layer = Softmax(nr_class, tok2vec.nO) else: - output_layer = ( - zero_init(Affine(nr_class, tok2vec.nO)) - >> logistic - ) - model = ( - tok2vec - >> flatten_add_lengths - >> Pooling(mean_pool) - >> output_layer - ) + output_layer = zero_init(Affine(nr_class, tok2vec.nO)) >> logistic + model = tok2vec >> flatten_add_lengths >> Pooling(mean_pool) >> output_layer model.tok2vec = chain(tok2vec, flatten) model.nO = nr_class return model @@ -637,3 +629,79 @@ def concatenate_lists(*layers, **kwargs): # pragma: no cover model = wrap(concatenate_lists_fwd, concat) return model + + +def masked_language_model(vocab, model, mask_prob=0.15): + """Convert a model into a BERT-style masked language model""" + + random_words = _RandomWords(vocab) + + def mlm_forward(docs, drop=0.0): + mask, docs = _apply_mask(docs, random_words, mask_prob=mask_prob) + mask = model.ops.asarray(mask).reshape((mask.shape[0], 1)) + output, backprop = model.begin_update(docs, drop=drop) + + def mlm_backward(d_output, sgd=None): + d_output *= 1 - mask + return backprop(d_output, sgd=sgd) + + return output, mlm_backward + + return wrap(mlm_forward, model) + + +class _RandomWords(object): + def __init__(self, vocab): + self.words = [lex.text for lex in vocab if lex.prob != 0.0] + self.probs = [lex.prob for lex in vocab if lex.prob != 0.0] + self.words = self.words[:10000] + self.probs = self.probs[:10000] + self.probs = numpy.exp(numpy.array(self.probs, dtype="f")) + self.probs /= self.probs.sum() + self._cache = [] + + def next(self): + if not self._cache: + self._cache.extend( + numpy.random.choice(len(self.words), 10000, p=self.probs) + ) + index = self._cache.pop() + return self.words[index] + + +def _apply_mask(docs, random_words, mask_prob=0.15): + # This needs to be here to avoid circular imports + from .tokens.doc import Doc + + N = sum(len(doc) for doc in docs) + mask = numpy.random.uniform(0.0, 1.0, (N,)) + mask = mask >= mask_prob + i = 0 + masked_docs = [] + for doc in docs: + words = [] + for token in doc: + if not mask[i]: + word = _replace_word(token.text, random_words) + else: + word = token.text + words.append(word) + i += 1 + spaces = [bool(w.whitespace_) for w in doc] + # NB: If you change this implementation to instead modify + # the docs in place, take care that the IDs reflect the original + # words. Currently we use the original docs to make the vectors + # for the target, so we don't lose the original tokens. But if + # you modified the docs in place here, you would. + masked_docs.append(Doc(doc.vocab, words=words, spaces=spaces)) + return mask, masked_docs + + +def _replace_word(word, random_words, mask="[MASK]"): + roll = numpy.random.random() + if roll < 0.8: + return mask + elif roll < 0.9: + return random_words.next() + else: + return word diff --git a/spacy/cli/pretrain.py b/spacy/cli/pretrain.py index 70cab05c2..598739246 100644 --- a/spacy/cli/pretrain.py +++ b/spacy/cli/pretrain.py @@ -17,6 +17,7 @@ import srsly from ..tokens import Doc from ..attrs import ID, HEAD from .._ml import Tok2Vec, flatten, chain, zero_init, create_default_optimizer +from .._ml import masked_language_model from .. 
diff --git a/spacy/cli/pretrain.py b/spacy/cli/pretrain.py
index 70cab05c2..598739246 100644
--- a/spacy/cli/pretrain.py
+++ b/spacy/cli/pretrain.py
@@ -17,6 +17,7 @@ import srsly
 from ..tokens import Doc
 from ..attrs import ID, HEAD
 from .._ml import Tok2Vec, flatten, chain, zero_init, create_default_optimizer
+from .._ml import masked_language_model
 from .. import util
@@ -212,79 +213,6 @@ def create_pretraining_model(nlp, tok2vec):
     return model
 
 
-def masked_language_model(vocab, model, mask_prob=0.15):
-    """Convert a model into a BERT-style masked language model"""
-
-    random_words = RandomWords(vocab)
-
-    def mlm_forward(docs, drop=0.0):
-        mask, docs = apply_mask(docs, random_words, mask_prob=mask_prob)
-        mask = model.ops.asarray(mask).reshape((mask.shape[0], 1))
-        output, backprop = model.begin_update(docs, drop=drop)
-
-        def mlm_backward(d_output, sgd=None):
-            d_output *= 1 - mask
-            return backprop(d_output, sgd=sgd)
-
-        return output, mlm_backward
-
-    return wrap(mlm_forward, model)
-
-
-def apply_mask(docs, random_words, mask_prob=0.15):
-    N = sum(len(doc) for doc in docs)
-    mask = numpy.random.uniform(0.0, 1.0, (N,))
-    mask = mask >= mask_prob
-    i = 0
-    masked_docs = []
-    for doc in docs:
-        words = []
-        for token in doc:
-            if not mask[i]:
-                word = replace_word(token.text, random_words)
-            else:
-                word = token.text
-            words.append(word)
-            i += 1
-        spaces = [bool(w.whitespace_) for w in doc]
-        # NB: If you change this implementation to instead modify
-        # the docs in place, take care that the IDs reflect the original
-        # words. Currently we use the original docs to make the vectors
-        # for the target, so we don't lose the original tokens. But if
-        # you modified the docs in place here, you would.
-        masked_docs.append(Doc(doc.vocab, words=words, spaces=spaces))
-    return mask, masked_docs
-
-
-def replace_word(word, random_words, mask="[MASK]"):
-    roll = random.random()
-    if roll < 0.8:
-        return mask
-    elif roll < 0.9:
-        return random_words.next()
-    else:
-        return word
-
-
-class RandomWords(object):
-    def __init__(self, vocab):
-        self.words = [lex.text for lex in vocab if lex.prob != 0.0]
-        self.probs = [lex.prob for lex in vocab if lex.prob != 0.0]
-        self.words = self.words[:10000]
-        self.probs = self.probs[:10000]
-        self.probs = numpy.exp(numpy.array(self.probs, dtype="f"))
-        self.probs /= self.probs.sum()
-        self._cache = []
-
-    def next(self):
-        if not self._cache:
-            self._cache.extend(
-                numpy.random.choice(len(self.words), 10000, p=self.probs)
-            )
-        index = self._cache.pop()
-        return self.words[index]
-
-
 class ProgressTracker(object):
     def __init__(self, frequency=1000000):
         self.loss = 0.0
diff --git a/spacy/cli/train.py b/spacy/cli/train.py
index acf26e548..dc205196a 100644
--- a/spacy/cli/train.py
+++ b/spacy/cli/train.py
@@ -25,6 +25,12 @@ from .. import about
     output_path=("Output directory to store model in", "positional", None, Path),
     train_path=("Location of JSON-formatted training data", "positional", None, Path),
    dev_path=("Location of JSON-formatted development data", "positional", None, Path),
+    raw_text=(
+        "Path to jsonl file with unlabelled text documents.",
+        "option",
+        "rt",
+        Path,
+    ),
     base_model=("Name of model to update (optional)", "option", "b", str),
     pipeline=("Comma-separated names of pipeline components", "option", "p", str),
     vectors=("Model to load vectors from", "option", "v", str),
@@ -62,6 +68,7 @@ def train(
     output_path,
     train_path,
     dev_path,
+    raw_text=None,
     base_model=None,
     pipeline="tagger,parser,ner",
     vectors=None,
@@ -92,6 +99,8 @@ def train(
     train_path = util.ensure_path(train_path)
     dev_path = util.ensure_path(dev_path)
     meta_path = util.ensure_path(meta_path)
+    if raw_text is not None:
+        raw_text = list(srsly.read_jsonl(raw_text))
     if not train_path or not train_path.exists():
         msg.fail("Training data not found", train_path, exits=1)
     if not dev_path or not dev_path.exists():
@@ -186,6 +195,8 @@ def train(
     optimizer.b1_decay = 0.0001
     optimizer.b2_decay = 0.0001
     nlp._optimizer = None
+    optimizer.b1_decay = 0.003
+    optimizer.b2_decay = 0.003
 
     # Load in pre-trained weights
     if init_tok2vec is not None:
@@ -208,6 +219,11 @@ def train(
             train_docs = corpus.train_docs(
                 nlp, noise_level=noise_level, gold_preproc=gold_preproc, max_length=0
             )
+            if raw_text:
+                random.shuffle(raw_text)
+                raw_batches = util.minibatch(
+                    (nlp.make_doc(rt["text"]) for rt in raw_text), size=8
+                )
             words_seen = 0
             with _create_progress_bar(n_train_words) as pbar:
                 losses = {}
@@ -222,7 +238,12 @@ def train(
                         drop=next(dropout_rates),
                         losses=losses,
                     )
-                    if not int(os.environ.get('LOG_FRIENDLY', 0)):
+                    if raw_text:
+                        # If raw text is available, perform 'rehearsal' updates,
+                        # which use unlabelled data to reduce overfitting.
+                        raw_batch = list(next(raw_batches))
+                        nlp.rehearse(raw_batch, sgd=optimizer, losses=losses)
+                    if not int(os.environ.get("LOG_FRIENDLY", 0)):
                         pbar.update(sum(len(doc) for doc in docs))
                     words_seen += sum(len(doc) for doc in docs)
             with nlp.use_params(optimizer.averages):
@@ -286,7 +307,7 @@ def train(
 
 @contextlib.contextmanager
 def _create_progress_bar(total):
-    if int(os.environ.get('LOG_FRIENDLY', 0)):
+    if int(os.environ.get("LOG_FRIENDLY", 0)):
         yield
     else:
         pbar = tqdm.tqdm(total=total, leave=False)
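The rehearsal branch added to the training loop above pairs one unlabelled batch with every labelled batch. A rough per-epoch sketch of that interleaving, assuming nlp, optimizer, raw_text (a list of {"text": ...} dicts) and an iterable of (docs, golds) batches already exist, and that the raw corpus is at least as large as the labelled data, as the loop above assumes:

    import random
    from spacy.util import minibatch

    def run_epoch_with_rehearsal(nlp, optimizer, labelled_batches, raw_text):
        losses = {}
        random.shuffle(raw_text)
        raw_batches = minibatch((nlp.make_doc(rt["text"]) for rt in raw_text), size=8)
        for docs, golds in labelled_batches:
            nlp.update(docs, golds, sgd=optimizer, losses=losses)
            # One rehearsal batch per labelled batch, as in the loop above.
            nlp.rehearse(list(next(raw_batches)), sgd=optimizer, losses=losses)
        return losses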
diff --git a/spacy/language.py b/spacy/language.py
index 4c3bfd5c8..2f0498c47 100644
--- a/spacy/language.py
+++ b/spacy/language.py
@@ -7,7 +7,7 @@
 import weakref
 import functools
 from collections import OrderedDict
 from contextlib import contextmanager
-from copy import copy
+from copy import copy, deepcopy
 from thinc.neural import Model
 import srsly
@@ -453,6 +453,59 @@ class Language(object):
         for key, (W, dW) in grads.items():
             sgd(W, dW, key=key)
 
+    def rehearse(self, docs, sgd=None, losses=None, config=None):
+        """Make a "rehearsal" update to the models in the pipeline, to prevent
+        forgetting. Rehearsal updates run an initial copy of the model over some
+        data, and update the model so its current predictions are more like the
+        initial ones. This is useful for keeping a pre-trained model on-track,
+        even if you're updating it with a smaller set of examples.
+
+        docs (iterable): A batch of `Doc` objects.
+        drop (float): The dropout rate.
+        sgd (callable): An optimizer.
+        RETURNS (dict): Results from the update.
+
+        EXAMPLE:
+            >>> raw_text_batches = minibatch(raw_texts)
+            >>> for labelled_batch in minibatch(zip(train_docs, train_golds)):
+            >>>     docs, golds = zip(*labelled_batch)
+            >>>     nlp.update(docs, golds)
+            >>>     raw_batch = [nlp.make_doc(text) for text in next(raw_text_batches)]
+            >>>     nlp.rehearse(raw_batch)
+        """
+        if len(docs) == 0:
+            return
+        if sgd is None:
+            if self._optimizer is None:
+                self._optimizer = create_default_optimizer(Model.ops)
+            sgd = self._optimizer
+        docs = list(docs)
+        for i, doc in enumerate(docs):
+            if isinstance(doc, basestring_):
+                docs[i] = self.make_doc(doc)
+        pipes = list(self.pipeline)
+        random.shuffle(pipes)
+        if config is None:
+            config = {}
+        grads = {}
+
+        def get_grads(W, dW, key=None):
+            grads[key] = (W, dW)
+
+        get_grads.alpha = sgd.alpha
+        get_grads.b1 = sgd.b1
+        get_grads.b2 = sgd.b2
+
+        for name, proc in pipes:
+            if not hasattr(proc, "rehearse"):
+                continue
+            grads = {}
+            proc.rehearse(docs, sgd=get_grads, losses=losses, **config.get(name, {}))
+            for key, (W, dW) in grads.items():
+                sgd(W, dW, key=key)
+
+        return losses
+
     def preprocess_gold(self, docs_golds):
         """Can be called before training to pre-process gold data. By default,
         it handles nonprojectivity and adds missing tags to the tag map.
@@ -499,6 +552,30 @@ class Language(object):
         )
         return self._optimizer
 
+    def resume_training(self, sgd=None, **cfg):
+        """Continue training a pre-trained model.
+
+        Create and return an optimizer, and initialize "rehearsal" for any pipeline
+        component that has a .rehearse() method. Rehearsal is used to prevent
+        models from "forgetting" their initialised "knowledge". To perform
+        rehearsal, collect samples of text you want the models to retain performance
+        on, and call nlp.rehearse() with a batch of Doc objects.
+        """
+        if cfg.get("device", -1) >= 0:
+            util.use_gpu(cfg["device"])
+            if self.vocab.vectors.data.shape[1] >= 1:
+                self.vocab.vectors.data = Model.ops.asarray(self.vocab.vectors.data)
+        link_vectors_to_models(self.vocab)
+        if self.vocab.vectors.data.shape[1]:
+            cfg["pretrained_vectors"] = self.vocab.vectors.name
+        if sgd is None:
+            sgd = create_default_optimizer(Model.ops)
+        self._optimizer = sgd
+        for name, proc in self.pipeline:
+            if hasattr(proc, "_rehearsal_model"):
+                proc._rehearsal_model = deepcopy(proc.model)
+        return self._optimizer
+
     def evaluate(self, docs_golds, verbose=False):
         scorer = Scorer()
         docs, golds = zip(*docs_golds)
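The rehearse() and resume_training() methods above are designed to be used together: resume_training() snapshots each component that exposes a _rehearsal_model, and rehearse() then pulls the updated weights back towards that snapshot. A minimal usage sketch, with a placeholder model name and toy data:

    import spacy

    nlp = spacy.load("en_core_web_sm")           # placeholder model name
    nlp.get_pipe("ner").add_label("ANIMAL")
    optimizer = nlp.resume_training()            # snapshots components for rehearsal

    losses, r_losses = {}, {}
    train_data = [("Horses are great", {"entities": [(0, 6, "ANIMAL")]})]
    revision_texts = ["Do you like horses?", "London is a big city."]

    for text, annotations in train_data:
        nlp.update([text], [annotations], sgd=optimizer, drop=0.2, losses=losses)
        raw_batch = [nlp.make_doc(t) for t in revision_texts]
        nlp.rehearse(raw_batch, sgd=optimizer, losses=r_losses)
    print(losses, r_losses)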
diff --git a/spacy/pipeline.pyx b/spacy/pipeline.pyx
index 0f59acad5..ad646f25f 100644
--- a/spacy/pipeline.pyx
+++ b/spacy/pipeline.pyx
@@ -33,6 +33,7 @@ from ._ml import Tok2Vec, build_text_classifier, build_tagger_model
 from ._ml import build_simple_cnn_text_classifier
 from ._ml import link_vectors_to_models, zero_init, flatten
 from ._ml import create_default_optimizer
+from ._ml import masked_language_model
 from .errors import Errors, TempErrors
 from .compat import basestring_
 from . import util
@@ -326,6 +327,9 @@ class Pipe(object):
         """
         raise NotImplementedError
 
+    def rehearse(self, docs, sgd=None, losses=None, **config):
+        pass
+
     def get_loss(self, docs, golds, scores):
         """Find the loss and gradient of loss for the batch of documents and
         their predicted scores."""
@@ -568,6 +572,7 @@ class Tagger(Pipe):
     def __init__(self, vocab, model=True, **cfg):
         self.vocab = vocab
         self.model = model
+        self._rehearsal_model = None
         self.cfg = OrderedDict(sorted(cfg.items()))
         self.cfg.setdefault('cnn_maxout_pieces', 2)
 
@@ -649,6 +654,20 @@ class Tagger(Pipe):
         if losses is not None:
             losses[self.name] += loss
 
+    def rehearse(self, docs, drop=0., sgd=None, losses=None):
+        """Perform a 'rehearsal' update, where we try to match the output of
+        an initial model.
+        """
+        if self._rehearsal_model is None:
+            return
+        guesses, backprop = self.model.begin_update(docs, drop=drop)
+        target = self._rehearsal_model(docs)
+        gradient = guesses - target
+        backprop(gradient, sgd=sgd)
+        if losses is not None:
+            losses.setdefault(self.name, 0.0)
+            losses[self.name] += (gradient**2).sum()
+
     def get_loss(self, docs, golds, scores):
         scores = self.model.ops.flatten(scores)
         tag_index = {tag: i for i, tag in enumerate(self.labels)}
@@ -986,6 +1005,69 @@ class MultitaskObjective(Tagger):
         return sent_tags[target]
 
 
+class ClozeMultitask(Pipe):
+    @classmethod
+    def Model(cls, vocab, tok2vec, **cfg):
+        output_size = vocab.vectors.data.shape[1]
+        output_layer = chain(
+            LayerNorm(Maxout(output_size, tok2vec.nO, pieces=3)),
+            zero_init(Affine(output_size, output_size, drop_factor=0.0))
+        )
+        model = chain(tok2vec, output_layer)
+        model = masked_language_model(vocab, model)
+        model.tok2vec = tok2vec
+        model.output_layer = output_layer
+        return model
+
+    def __init__(self, vocab, model=True, **cfg):
+        self.vocab = vocab
+        self.model = model
+        self.cfg = cfg
+
+    def set_annotations(self, docs, dep_ids, tensors=None):
+        pass
+
+    def begin_training(self, get_gold_tuples=lambda: [], pipeline=None,
+                       tok2vec=None, sgd=None, **kwargs):
+        link_vectors_to_models(self.vocab)
+        if self.model is True:
+            self.model = self.Model(self.vocab, tok2vec)
+        X = self.model.ops.allocate((5, self.model.tok2vec.nO))
+        self.model.output_layer.begin_training(X)
+        if sgd is None:
+            sgd = self.create_optimizer()
+        return sgd
+
+    def predict(self, docs):
+        tokvecs = self.model.tok2vec(docs)
+        vectors = self.model.output_layer(tokvecs)
+        return tokvecs, vectors
+
+    def get_loss(self, docs, vectors, prediction):
+        # The simplest way to implement this would be to vstack the
+        # token.vector values, but that's a bit inefficient, especially on GPU.
+        # Instead we fetch the index into the vectors table for each of our tokens,
+        # and look them up all at once. This prevents data copying.
+        ids = self.model.ops.flatten([doc.to_array(ID).ravel() for doc in docs])
+        target = vectors[ids]
+        gradient = (prediction - target) / prediction.shape[0]
+        loss = (gradient**2).sum()
+        return float(loss), gradient
+
+    def update(self, docs, golds, drop=0., sgd=None, losses=None):
+        pass
+
+    def rehearse(self, docs, drop=0., sgd=None, losses=None):
+        if losses is not None and self.name not in losses:
+            losses[self.name] = 0.
+        predictions, bp_predictions = self.model.begin_update(docs, drop=drop)
+        loss, d_predictions = self.get_loss(docs, self.vocab.vectors.data, predictions)
+        bp_predictions(d_predictions, sgd=sgd)
+
+        if losses is not None:
+            losses[self.name] += loss
+
+
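ClozeMultitask.get_loss() above avoids stacking token.vector rows by indexing straight into the vocab's vectors table. A small numpy sketch of that target lookup and the resulting gradient, with toy shapes and illustrative names:

    import numpy

    def cloze_loss_sketch(ids, vectors_table, prediction):
        # Gather target vectors by row id in one indexing operation,
        # as ClozeMultitask.get_loss() does above.
        target = vectors_table[ids]
        gradient = (prediction - target) / prediction.shape[0]
        loss = (gradient ** 2).sum()
        return float(loss), gradient

    vectors_table = numpy.random.uniform(-1, 1, (100, 8)).astype("f")
    ids = numpy.asarray([3, 17, 42])
    prediction = numpy.zeros((3, 8), dtype="f")
    loss, gradient = cloze_loss_sketch(ids, vectors_table, prediction)
    print(loss, gradient.shape)  # scalar loss, (3, 8) gradient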
 class SimilarityHook(Pipe):
     """
     Experimental: A pipeline component to install a hook for supervised
@@ -1062,6 +1144,7 @@ class TextCategorizer(Pipe):
     def __init__(self, vocab, model=True, **cfg):
         self.vocab = vocab
         self.model = model
+        self._rehearsal_model = None
         self.cfg = dict(cfg)
 
     @property
@@ -1103,6 +1186,17 @@ class TextCategorizer(Pipe):
             losses.setdefault(self.name, 0.0)
             losses[self.name] += loss
 
+    def rehearse(self, docs, drop=0., sgd=None, losses=None):
+        if self._rehearsal_model is None:
+            return
+        scores, bp_scores = self.model.begin_update(docs, drop=drop)
+        target = self._rehearsal_model(docs)
+        gradient = scores - target
+        bp_scores(gradient, sgd=sgd)
+        if losses is not None:
+            losses.setdefault(self.name, 0.0)
+            losses[self.name] += (gradient**2).sum()
+
     def get_loss(self, docs, golds, scores):
         truths = numpy.zeros((len(golds), len(self.labels)), dtype='f')
         not_missing = numpy.ones((len(golds), len(self.labels)), dtype='f')
@@ -1165,8 +1259,12 @@ cdef class DependencyParser(Parser):
         return [nonproj.deprojectivize]
 
     def add_multitask_objective(self, target):
-        labeller = MultitaskObjective(self.vocab, target=target)
-        self._multitasks.append(labeller)
+        if target == 'cloze':
+            cloze = ClozeMultitask(self.vocab)
+            self._multitasks.append(cloze)
+        else:
+            labeller = MultitaskObjective(self.vocab, target=target)
+            self._multitasks.append(labeller)
 
     def init_multitask_objectives(self, get_gold_tuples, pipeline, sgd=None, **cfg):
         for labeller in self._multitasks:
@@ -1186,8 +1284,12 @@ cdef class EntityRecognizer(Parser):
     nr_feature = 6
 
     def add_multitask_objective(self, target):
-        labeller = MultitaskObjective(self.vocab, target=target)
-        self._multitasks.append(labeller)
+        if target == 'cloze':
+            cloze = ClozeMultitask(self.vocab)
+            self._multitasks.append(cloze)
+        else:
+            labeller = MultitaskObjective(self.vocab, target=target)
+            self._multitasks.append(labeller)
 
     def init_multitask_objectives(self, get_gold_tuples, pipeline, sgd=None, **cfg):
         for labeller in self._multitasks:
diff --git a/spacy/syntax/_parser_model.pyx b/spacy/syntax/_parser_model.pyx
index f60354be8..5615fcea1 100644
--- a/spacy/syntax/_parser_model.pyx
+++ b/spacy/syntax/_parser_model.pyx
@@ -193,10 +193,6 @@ class ParserModel(Model):
         Model.__init__(self)
         self._layers = [tok2vec, lower_model, upper_model]
 
-    @property
-    def tok2vec(self):
-        return self._layers[0]
-
     def begin_update(self, docs, drop=0.):
         step_model = ParserStepModel(docs, self._layers, drop=drop)
         def finish_parser_update(golds, sgd=None):
         return step_model, finish_parser_update
@@ -205,13 +201,20 @@
     def resize_output(self, new_output):
+        smaller = self.upper
+        larger = Affine(new_output, smaller.nI)
+        larger.W *= 0
+        # It seems very unhappy if I pass these as smaller.W?
+        # Seems to segfault. Maybe it's a descriptor protocol thing?
+        smaller_W = smaller.W
+        larger_W = larger.W
+        smaller_b = smaller.b
+        larger_b = larger.b
         # Weights are stored in (nr_out, nr_in) format, so we're basically
         # just adding rows here.
-        smaller = self._layers[-1]._layers[-1]
-        larger = Affine(self.moves.n_moves, smaller.nI)
-        copy_array(larger.W[:smaller.nO], smaller.W)
-        copy_array(larger.b[:smaller.nO], smaller.b)
-        self._layers[-1]._layers[-1] = larger
+        larger_W[:smaller.nO] = smaller_W
+        larger_b[:smaller.nO] = smaller_b
+        self._layers[-1] = larger
 
     def begin_training(self, X, y=None):
         self.lower.begin_training(X, y=y)
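ParserModel.resize_output() above grows the output layer by allocating a zeroed (new_nr_class, nr_in) weight matrix and copying the old rows into it, so existing classes keep their learned weights. A plain numpy sketch of that row-append idea (names are illustrative, not the actual thinc API):

    import numpy

    def resize_output_sketch(W, b, new_nr_class):
        # Weights are (nr_out, nr_in): growing the class set appends zeroed rows.
        new_W = numpy.zeros((new_nr_class, W.shape[1]), dtype=W.dtype)
        new_b = numpy.zeros((new_nr_class,), dtype=b.dtype)
        new_W[: W.shape[0]] = W
        new_b[: b.shape[0]] = b
        return new_W, new_b

    W = numpy.ones((3, 4), dtype="f")
    b = numpy.ones((3,), dtype="f")
    W2, b2 = resize_output_sketch(W, b, 5)
    assert W2.shape == (5, 4) and (W2[:3] == W).all() and (W2[3:] == 0).all()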
diff --git a/spacy/syntax/nn_parser.pxd b/spacy/syntax/nn_parser.pxd
index 135096317..707c9654c 100644
--- a/spacy/syntax/nn_parser.pxd
+++ b/spacy/syntax/nn_parser.pxd
@@ -12,6 +12,7 @@ from ._parser_model cimport WeightsC, ActivationsC, SizesC
 cdef class Parser:
     cdef readonly Vocab vocab
     cdef public object model
+    cdef public object _rehearsal_model
     cdef readonly TransitionSystem moves
     cdef readonly object cfg
     cdef public object _multitasks
@@ -21,4 +22,3 @@ cdef class Parser:
 
     cdef void c_transition_batch(self, StateC** states, const float* scores,
         int nr_class, int batch_size) nogil
-
diff --git a/spacy/syntax/nn_parser.pyx b/spacy/syntax/nn_parser.pyx
index a8809b4e6..936ba8e8d 100644
--- a/spacy/syntax/nn_parser.pyx
+++ b/spacy/syntax/nn_parser.pyx
@@ -72,13 +72,15 @@ cdef class Parser:
                               pretrained_vectors=pretrained_vectors,
                               bilstm_depth=bilstm_depth)
         tok2vec = chain(tok2vec, flatten)
+        tok2vec.nO = token_vector_width
         lower = PrecomputableAffine(hidden_width,
                     nF=cls.nr_feature, nI=token_vector_width,
                     nP=parser_maxout_pieces)
         lower.nP = parser_maxout_pieces
         with Model.use_device('cpu'):
-            upper = zero_init(Affine(nr_class, hidden_width, drop_factor=0.0))
+            upper = Affine(nr_class, hidden_width, drop_factor=0.0)
+            upper.W *= 0
 
         cfg = {
             'nr_class': nr_class,
@@ -121,6 +123,7 @@ cdef class Parser:
         self.cfg = cfg
         self.model = model
         self._multitasks = []
+        self._rehearsal_model = None
 
     def __reduce__(self):
         return (Parser, (self.vocab, self.moves, self.model), None, None)
@@ -404,6 +407,43 @@ cdef class Parser:
             finish_update(golds, sgd=sgd)
         return losses
 
+    def rehearse(self, docs, sgd=None, losses=None, **cfg):
+        """Perform a "rehearsal" update, to prevent catastrophic forgetting."""
+        if isinstance(docs, Doc):
+            docs = [docs]
+        if losses is None:
+            losses = {}
+        for multitask in self._multitasks:
+            if hasattr(multitask, 'rehearse'):
+                multitask.rehearse(docs, losses=losses, sgd=sgd)
+        if self._rehearsal_model is None:
+            return None
+        losses.setdefault(self.name, 0.)
+        # Prepare the stepwise model, and get the callback for finishing the batch
+        tutor = self._rehearsal_model(docs)
+        model, finish_update = self.model.begin_update(docs, drop=0.0)
+        states = self.moves.init_batch(docs)
+        n_scores = 0.
+        loss = 0.
+        non_zeroed_classes = self._rehearsal_model.upper.W.any(axis=1)
+        while states:
+            targets, _ = tutor.begin_update(states)
+            guesses, backprop = model.begin_update(states)
+            d_scores = (targets - guesses) / targets.shape[0]
+            d_scores *= non_zeroed_classes
+            # If all weights for an output are 0 in the original model, don't
+            # supervise that output. This allows us to add classes.
+            loss += (d_scores**2).sum()
+            backprop(d_scores, sgd=sgd)
+            # Follow the predicted action
+            self.transition_states(states, guesses)
+            states = [state for state in states if not state.is_final()]
+            n_scores += d_scores.size
+        # Do the backprop
+        finish_update(docs, sgd=sgd)
+        losses[self.name] += loss / n_scores
+        return losses
+
     def update_beam(self, docs, golds, width, drop=0., sgd=None, losses=None,
                     beam_density=0.0):
         lengths = [len(d) for d in docs]
@@ -416,7 +456,7 @@ cdef class Parser:
                 model.vec2scores, width, drop=drop, losses=losses,
                 beam_density=beam_density)
         for i, d_scores in enumerate(states_d_scores):
-            losses[self.name] += (d_scores**2).sum()
+            losses[self.name] += (d_scores**2).mean()
             ids, bp_vectors, bp_scores = backprops[i]
             d_vector = bp_scores(d_scores, sgd=sgd)
             if isinstance(model.ops, CupyOps) \
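Parser.rehearse() above masks its gradient with non_zeroed_classes so that outputs whose weights are all zero in the frozen copy (i.e. classes added after resume_training()) receive no rehearsal supervision. A toy numpy sketch of that masking, with illustrative names and values:

    import numpy

    def masked_rehearsal_gradient(targets, guesses, original_upper_W):
        # Columns whose original weights are entirely zero are new classes;
        # they are excluded from the rehearsal gradient.
        non_zeroed_classes = original_upper_W.any(axis=1)
        d_scores = (targets - guesses) / targets.shape[0]
        return d_scores * non_zeroed_classes

    targets = numpy.asarray([[0.6, 0.4, 0.0]], dtype="f")
    guesses = numpy.asarray([[0.3, 0.3, 0.4]], dtype="f")
    original_W = numpy.asarray([[1.0, 2.0], [0.5, 0.1], [0.0, 0.0]], dtype="f")
    d_scores = masked_rehearsal_gradient(targets, guesses, original_W)
    print(d_scores)  # third column is zero: the newly added class gets no supervision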