diff --git a/spacy/language.py b/spacy/language.py index da58d1e76..a28f2a84e 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -449,29 +449,9 @@ class Language(object): def make_doc(self, text): return self.tokenizer(text) - def update(self, docs, golds, drop=0.0, sgd=None, losses=None, component_cfg=None): - """Update the models in the pipeline. - - docs (iterable): A batch of `Doc` objects. - golds (iterable): A batch of `GoldParse` objects. - drop (float): The droput rate. - sgd (callable): An optimizer. - losses (dict): Dictionary to update with the loss, keyed by component. - component_cfg (dict): Config parameters for specific pipeline - components, keyed by component name. - - DOCS: https://spacy.io/api/language#update - """ + def _format_docs_and_golds(self, docs, golds): + """Format golds and docs before updating the models.""" expected_keys = ("words", "tags", "heads", "deps", "entities", "cats", "links") - if len(docs) != len(golds): - raise IndexError(Errors.E009.format(n_docs=len(docs), n_golds=len(golds))) - if len(docs) == 0: - return - if sgd is None: - if self._optimizer is None: - self._optimizer = create_default_optimizer(Model.ops) - sgd = self._optimizer - # Allow dict of args to GoldParse, instead of GoldParse objects. gold_objs = [] doc_objs = [] for doc, gold in zip(docs, golds): @@ -485,8 +465,32 @@ class Language(object): gold = GoldParse(doc, **gold) doc_objs.append(doc) gold_objs.append(gold) - golds = gold_objs - docs = doc_objs + + return doc_objs, gold_objs + + def update(self, docs, golds, drop=0.0, sgd=None, losses=None, component_cfg=None): + """Update the models in the pipeline. + + docs (iterable): A batch of `Doc` objects. + golds (iterable): A batch of `GoldParse` objects. + drop (float): The dropout rate. + sgd (callable): An optimizer. + losses (dict): Dictionary to update with the loss, keyed by component. + component_cfg (dict): Config parameters for specific pipeline + components, keyed by component name.
+ + DOCS: https://spacy.io/api/language#update + """ + if len(docs) != len(golds): + raise IndexError(Errors.E009.format(n_docs=len(docs), n_golds=len(golds))) + if len(docs) == 0: + return + if sgd is None: + if self._optimizer is None: + self._optimizer = create_default_optimizer(Model.ops) + sgd = self._optimizer + # Allow dict of args to GoldParse, instead of GoldParse objects. + docs, golds = self._format_docs_and_golds(docs, golds) grads = {} def get_grads(W, dW, key=None):