mirror of https://github.com/explosion/spaCy.git
Refactor language update (#4316)
* refactor: separate formatting docs and golds in Language.update * fix return typo
This commit is contained in:
parent
105a91975b
commit
b408b5b29e
|
@ -449,29 +449,9 @@ class Language(object):
|
|||
def make_doc(self, text):
|
||||
return self.tokenizer(text)
|
||||
|
||||
def update(self, docs, golds, drop=0.0, sgd=None, losses=None, component_cfg=None):
|
||||
"""Update the models in the pipeline.
|
||||
|
||||
docs (iterable): A batch of `Doc` objects.
|
||||
golds (iterable): A batch of `GoldParse` objects.
|
||||
drop (float): The droput rate.
|
||||
sgd (callable): An optimizer.
|
||||
losses (dict): Dictionary to update with the loss, keyed by component.
|
||||
component_cfg (dict): Config parameters for specific pipeline
|
||||
components, keyed by component name.
|
||||
|
||||
DOCS: https://spacy.io/api/language#update
|
||||
"""
|
||||
def _format_docs_and_golds(self, docs, golds):
|
||||
"""Format golds and docs before update models."""
|
||||
expected_keys = ("words", "tags", "heads", "deps", "entities", "cats", "links")
|
||||
if len(docs) != len(golds):
|
||||
raise IndexError(Errors.E009.format(n_docs=len(docs), n_golds=len(golds)))
|
||||
if len(docs) == 0:
|
||||
return
|
||||
if sgd is None:
|
||||
if self._optimizer is None:
|
||||
self._optimizer = create_default_optimizer(Model.ops)
|
||||
sgd = self._optimizer
|
||||
# Allow dict of args to GoldParse, instead of GoldParse objects.
|
||||
gold_objs = []
|
||||
doc_objs = []
|
||||
for doc, gold in zip(docs, golds):
|
||||
|
@ -485,8 +465,32 @@ class Language(object):
|
|||
gold = GoldParse(doc, **gold)
|
||||
doc_objs.append(doc)
|
||||
gold_objs.append(gold)
|
||||
golds = gold_objs
|
||||
docs = doc_objs
|
||||
|
||||
return doc_objs, gold_objs
|
||||
|
||||
def update(self, docs, golds, drop=0.0, sgd=None, losses=None, component_cfg=None):
|
||||
"""Update the models in the pipeline.
|
||||
|
||||
docs (iterable): A batch of `Doc` objects.
|
||||
golds (iterable): A batch of `GoldParse` objects.
|
||||
drop (float): The droput rate.
|
||||
sgd (callable): An optimizer.
|
||||
losses (dict): Dictionary to update with the loss, keyed by component.
|
||||
component_cfg (dict): Config parameters for specific pipeline
|
||||
components, keyed by component name.
|
||||
|
||||
DOCS: https://spacy.io/api/language#update
|
||||
"""
|
||||
if len(docs) != len(golds):
|
||||
raise IndexError(Errors.E009.format(n_docs=len(docs), n_golds=len(golds)))
|
||||
if len(docs) == 0:
|
||||
return
|
||||
if sgd is None:
|
||||
if self._optimizer is None:
|
||||
self._optimizer = create_default_optimizer(Model.ops)
|
||||
sgd = self._optimizer
|
||||
# Allow dict of args to GoldParse, instead of GoldParse objects.
|
||||
docs, golds = self._format_docs_and_golds(docs, golds)
|
||||
grads = {}
|
||||
|
||||
def get_grads(W, dW, key=None):
|
||||
|
|
Loading…
Reference in New Issue