Refactor language update (#4316)

* refactor: separate formatting docs and golds in Language.update

* fix return typo
This commit is contained in:
tamuhey 2019-09-27 23:20:21 +09:00 committed by Ines Montani
parent 105a91975b
commit b408b5b29e
1 changed file with 28 additions and 24 deletions

View File

@ -449,29 +449,9 @@ class Language(object):
def make_doc(self, text): def make_doc(self, text):
return self.tokenizer(text) return self.tokenizer(text)
def update(self, docs, golds, drop=0.0, sgd=None, losses=None, component_cfg=None): def _format_docs_and_golds(self, docs, golds):
"""Update the models in the pipeline. """Format golds and docs before update models."""
docs (iterable): A batch of `Doc` objects.
golds (iterable): A batch of `GoldParse` objects.
drop (float): The dropout rate.
sgd (callable): An optimizer.
losses (dict): Dictionary to update with the loss, keyed by component.
component_cfg (dict): Config parameters for specific pipeline
components, keyed by component name.
DOCS: https://spacy.io/api/language#update
"""
expected_keys = ("words", "tags", "heads", "deps", "entities", "cats", "links") expected_keys = ("words", "tags", "heads", "deps", "entities", "cats", "links")
if len(docs) != len(golds):
raise IndexError(Errors.E009.format(n_docs=len(docs), n_golds=len(golds)))
if len(docs) == 0:
return
if sgd is None:
if self._optimizer is None:
self._optimizer = create_default_optimizer(Model.ops)
sgd = self._optimizer
# Allow dict of args to GoldParse, instead of GoldParse objects.
gold_objs = [] gold_objs = []
doc_objs = [] doc_objs = []
for doc, gold in zip(docs, golds): for doc, gold in zip(docs, golds):
@ -485,8 +465,32 @@ class Language(object):
gold = GoldParse(doc, **gold) gold = GoldParse(doc, **gold)
doc_objs.append(doc) doc_objs.append(doc)
gold_objs.append(gold) gold_objs.append(gold)
golds = gold_objs
docs = doc_objs return doc_objs, gold_objs
def update(self, docs, golds, drop=0.0, sgd=None, losses=None, component_cfg=None):
"""Update the models in the pipeline.
docs (iterable): A batch of `Doc` objects.
golds (iterable): A batch of `GoldParse` objects.
drop (float): The dropout rate.
sgd (callable): An optimizer.
losses (dict): Dictionary to update with the loss, keyed by component.
component_cfg (dict): Config parameters for specific pipeline
components, keyed by component name.
DOCS: https://spacy.io/api/language#update
"""
if len(docs) != len(golds):
raise IndexError(Errors.E009.format(n_docs=len(docs), n_golds=len(golds)))
if len(docs) == 0:
return
if sgd is None:
if self._optimizer is None:
self._optimizer = create_default_optimizer(Model.ops)
sgd = self._optimizer
# Allow dict of args to GoldParse, instead of GoldParse objects.
docs, golds = self._format_docs_and_golds(docs, golds)
grads = {} grads = {}
def get_grads(W, dW, key=None): def get_grads(W, dW, key=None):