mirror of https://github.com/explosion/spaCy.git
Refactor language update (#4316)
* refactor: separate formatting docs and golds in Language.update * fix return typo
This commit is contained in:
parent
105a91975b
commit
b408b5b29e
|
@ -449,29 +449,9 @@ class Language(object):
|
||||||
def make_doc(self, text):
|
def make_doc(self, text):
|
||||||
return self.tokenizer(text)
|
return self.tokenizer(text)
|
||||||
|
|
||||||
def update(self, docs, golds, drop=0.0, sgd=None, losses=None, component_cfg=None):
|
def _format_docs_and_golds(self, docs, golds):
|
||||||
"""Update the models in the pipeline.
|
"""Format golds and docs before update models."""
|
||||||
|
|
||||||
docs (iterable): A batch of `Doc` objects.
|
|
||||||
golds (iterable): A batch of `GoldParse` objects.
|
|
||||||
drop (float): The droput rate.
|
|
||||||
sgd (callable): An optimizer.
|
|
||||||
losses (dict): Dictionary to update with the loss, keyed by component.
|
|
||||||
component_cfg (dict): Config parameters for specific pipeline
|
|
||||||
components, keyed by component name.
|
|
||||||
|
|
||||||
DOCS: https://spacy.io/api/language#update
|
|
||||||
"""
|
|
||||||
expected_keys = ("words", "tags", "heads", "deps", "entities", "cats", "links")
|
expected_keys = ("words", "tags", "heads", "deps", "entities", "cats", "links")
|
||||||
if len(docs) != len(golds):
|
|
||||||
raise IndexError(Errors.E009.format(n_docs=len(docs), n_golds=len(golds)))
|
|
||||||
if len(docs) == 0:
|
|
||||||
return
|
|
||||||
if sgd is None:
|
|
||||||
if self._optimizer is None:
|
|
||||||
self._optimizer = create_default_optimizer(Model.ops)
|
|
||||||
sgd = self._optimizer
|
|
||||||
# Allow dict of args to GoldParse, instead of GoldParse objects.
|
|
||||||
gold_objs = []
|
gold_objs = []
|
||||||
doc_objs = []
|
doc_objs = []
|
||||||
for doc, gold in zip(docs, golds):
|
for doc, gold in zip(docs, golds):
|
||||||
|
@ -485,8 +465,32 @@ class Language(object):
|
||||||
gold = GoldParse(doc, **gold)
|
gold = GoldParse(doc, **gold)
|
||||||
doc_objs.append(doc)
|
doc_objs.append(doc)
|
||||||
gold_objs.append(gold)
|
gold_objs.append(gold)
|
||||||
golds = gold_objs
|
|
||||||
docs = doc_objs
|
return doc_objs, gold_objs
|
||||||
|
|
||||||
|
def update(self, docs, golds, drop=0.0, sgd=None, losses=None, component_cfg=None):
|
||||||
|
"""Update the models in the pipeline.
|
||||||
|
|
||||||
|
docs (iterable): A batch of `Doc` objects.
|
||||||
|
golds (iterable): A batch of `GoldParse` objects.
|
||||||
|
drop (float): The droput rate.
|
||||||
|
sgd (callable): An optimizer.
|
||||||
|
losses (dict): Dictionary to update with the loss, keyed by component.
|
||||||
|
component_cfg (dict): Config parameters for specific pipeline
|
||||||
|
components, keyed by component name.
|
||||||
|
|
||||||
|
DOCS: https://spacy.io/api/language#update
|
||||||
|
"""
|
||||||
|
if len(docs) != len(golds):
|
||||||
|
raise IndexError(Errors.E009.format(n_docs=len(docs), n_golds=len(golds)))
|
||||||
|
if len(docs) == 0:
|
||||||
|
return
|
||||||
|
if sgd is None:
|
||||||
|
if self._optimizer is None:
|
||||||
|
self._optimizer = create_default_optimizer(Model.ops)
|
||||||
|
sgd = self._optimizer
|
||||||
|
# Allow dict of args to GoldParse, instead of GoldParse objects.
|
||||||
|
docs, golds = self._format_docs_and_golds(docs, golds)
|
||||||
grads = {}
|
grads = {}
|
||||||
|
|
||||||
def get_grads(W, dW, key=None):
|
def get_grads(W, dW, key=None):
|
||||||
|
|
Loading…
Reference in New Issue