Remove 'device' argument from Language, clean up 'sgd' arg

This commit is contained in:
Matthew Honnibal 2020-09-29 11:42:19 +02:00
parent ff9a63bfbd
commit 5276db6f3f
1 changed file with 27 additions and 30 deletions

View File

@ -19,7 +19,7 @@ from .vocab import Vocab, create_vocab
from .pipe_analysis import validate_attrs, analyze_pipes, print_pipe_analysis from .pipe_analysis import validate_attrs, analyze_pipes, print_pipe_analysis
from .training import Example, validate_examples from .training import Example, validate_examples
from .scorer import Scorer from .scorer import Scorer
from .util import create_default_optimizer, registry, SimpleFrozenList from .util import registry, SimpleFrozenList
from .util import SimpleFrozenDict, combine_score_weights, CONFIG_SECTION_ORDER from .util import SimpleFrozenDict, combine_score_weights, CONFIG_SECTION_ORDER
from .lang.tokenizer_exceptions import URL_MATCH, BASE_EXCEPTIONS from .lang.tokenizer_exceptions import URL_MATCH, BASE_EXCEPTIONS
from .lang.punctuation import TOKENIZER_PREFIXES, TOKENIZER_SUFFIXES from .lang.punctuation import TOKENIZER_PREFIXES, TOKENIZER_SUFFIXES
@ -1065,7 +1065,7 @@ class Language:
validate_examples(examples, "Language.update") validate_examples(examples, "Language.update")
if sgd is None: if sgd is None:
if self._optimizer is None: if self._optimizer is None:
self._optimizer = create_default_optimizer() self._optimizer = self.create_optimizer()
sgd = self._optimizer sgd = self._optimizer
if component_cfg is None: if component_cfg is None:
component_cfg = {} component_cfg = {}
@ -1123,7 +1123,7 @@ class Language:
validate_examples(examples, "Language.rehearse") validate_examples(examples, "Language.rehearse")
if sgd is None: if sgd is None:
if self._optimizer is None: if self._optimizer is None:
self._optimizer = create_default_optimizer() self._optimizer = self.create_optimizer()
sgd = self._optimizer sgd = self._optimizer
pipes = list(self.pipeline) pipes = list(self.pipeline)
random.shuffle(pipes) random.shuffle(pipes)
@ -1161,16 +1161,14 @@ class Language:
def initialize( def initialize(
self, self,
get_examples: Optional[Callable[[], Iterable[Example]]] = None, get_examples: Optional[Callable[[], Iterable[Example]]] = None,
*, sgd: Optional[Optimizer]=None
sgd: Optional[Optimizer] = None, ) -> None:
device: int = -1,
) -> Optimizer:
"""Initialize the pipe for training, using data examples if available. """Initialize the pipe for training, using data examples if available.
get_examples (Callable[[], Iterable[Example]]): Optional function that get_examples (Callable[[], Iterable[Example]]): Optional function that
returns gold-standard Example objects. returns gold-standard Example objects.
sgd (thinc.api.Optimizer): Optional optimizer. Will be created with sgd (Optional[Optimizer]): An optimizer to use for updates. If not
create_optimizer if it doesn't exist. provided, will be created using the .create_optimizer() method.
RETURNS (thinc.api.Optimizer): The optimizer. RETURNS (thinc.api.Optimizer): The optimizer.
DOCS: https://nightly.spacy.io/api/language#initialize DOCS: https://nightly.spacy.io/api/language#initialize
@ -1199,25 +1197,22 @@ class Language:
if not valid_examples: if not valid_examples:
err = Errors.E930.format(name="Language", obj="empty list") err = Errors.E930.format(name="Language", obj="empty list")
raise ValueError(err) raise ValueError(err)
if device >= 0: # TODO: do we need this here? if self.vocab.vectors.data.shape[1] >= 1:
require_gpu(device) ops = get_current_ops()
if self.vocab.vectors.data.shape[1] >= 1: self.vocab.vectors.data = ops.asarray(self.vocab.vectors.data)
ops = get_current_ops()
self.vocab.vectors.data = ops.asarray(self.vocab.vectors.data)
if sgd is None:
sgd = create_default_optimizer()
self._optimizer = sgd
for name, proc in self.pipeline: for name, proc in self.pipeline:
if hasattr(proc, "initialize"): if hasattr(proc, "initialize"):
proc.initialize( proc.initialize(
get_examples, pipeline=self.pipeline, sgd=self._optimizer get_examples, pipeline=self.pipeline
) )
self._link_components() self._link_components()
if sgd is not None:
self._optimizer = sgd
elif self._optimizer is None:
self._optimizer = self.create_optimizer()
return self._optimizer return self._optimizer
def resume_training( def resume_training(self, *, sgd: Optional[Optimizer] = None) -> Optimizer:
self, *, sgd: Optional[Optimizer] = None, device: int = -1
) -> Optimizer:
"""Continue training a pretrained model. """Continue training a pretrained model.
Create and return an optimizer, and initialize "rehearsal" for any pipeline Create and return an optimizer, and initialize "rehearsal" for any pipeline
@ -1226,22 +1221,20 @@ class Language:
rehearsal, collect samples of text you want the models to retain performance rehearsal, collect samples of text you want the models to retain performance
on, and call nlp.rehearse() with a batch of Example objects. on, and call nlp.rehearse() with a batch of Example objects.
sgd (Optional[Optimizer]): An optimizer.
RETURNS (Optimizer): The optimizer. RETURNS (Optimizer): The optimizer.
DOCS: https://nightly.spacy.io/api/language#resume_training DOCS: https://nightly.spacy.io/api/language#resume_training
""" """
if device >= 0: # TODO: do we need this here? ops = get_current_ops()
require_gpu(device) if self.vocab.vectors.data.shape[1] >= 1:
ops = get_current_ops() self.vocab.vectors.data = ops.asarray(self.vocab.vectors.data)
if self.vocab.vectors.data.shape[1] >= 1:
self.vocab.vectors.data = ops.asarray(self.vocab.vectors.data)
if sgd is None:
sgd = create_default_optimizer()
self._optimizer = sgd
for name, proc in self.pipeline: for name, proc in self.pipeline:
if hasattr(proc, "_rehearsal_model"): if hasattr(proc, "_rehearsal_model"):
proc._rehearsal_model = deepcopy(proc.model) proc._rehearsal_model = deepcopy(proc.model)
if sgd is not None:
self._optimizer = sgd
elif self._optimizer is None:
self._optimizer = self.create_optimizer()
return self._optimizer return self._optimizer
def evaluate( def evaluate(
@ -1303,6 +1296,10 @@ class Language:
results["speed"] = n_words / (end_time - start_time) results["speed"] = n_words / (end_time - start_time)
return results return results
def create_optimizer(self):
"""Create an optimizer, usually using the [training.optimizer] config."""
return registry.resolve(self.config["training"]["optimizer"])
@contextmanager @contextmanager
def use_params(self, params: Optional[dict]): def use_params(self, params: Optional[dict]):
"""Replace weights of models in the pipeline with those provided in the """Replace weights of models in the pipeline with those provided in the