mirror of https://github.com/explosion/spaCy.git
Remove 'device' argument from Language, clean up 'sgd' arg
This commit is contained in:
parent
ff9a63bfbd
commit
5276db6f3f
|
@ -19,7 +19,7 @@ from .vocab import Vocab, create_vocab
|
||||||
from .pipe_analysis import validate_attrs, analyze_pipes, print_pipe_analysis
|
from .pipe_analysis import validate_attrs, analyze_pipes, print_pipe_analysis
|
||||||
from .training import Example, validate_examples
|
from .training import Example, validate_examples
|
||||||
from .scorer import Scorer
|
from .scorer import Scorer
|
||||||
from .util import create_default_optimizer, registry, SimpleFrozenList
|
from .util import registry, SimpleFrozenList
|
||||||
from .util import SimpleFrozenDict, combine_score_weights, CONFIG_SECTION_ORDER
|
from .util import SimpleFrozenDict, combine_score_weights, CONFIG_SECTION_ORDER
|
||||||
from .lang.tokenizer_exceptions import URL_MATCH, BASE_EXCEPTIONS
|
from .lang.tokenizer_exceptions import URL_MATCH, BASE_EXCEPTIONS
|
||||||
from .lang.punctuation import TOKENIZER_PREFIXES, TOKENIZER_SUFFIXES
|
from .lang.punctuation import TOKENIZER_PREFIXES, TOKENIZER_SUFFIXES
|
||||||
|
@ -1065,7 +1065,7 @@ class Language:
|
||||||
validate_examples(examples, "Language.update")
|
validate_examples(examples, "Language.update")
|
||||||
if sgd is None:
|
if sgd is None:
|
||||||
if self._optimizer is None:
|
if self._optimizer is None:
|
||||||
self._optimizer = create_default_optimizer()
|
self._optimizer = self.create_optimizer()
|
||||||
sgd = self._optimizer
|
sgd = self._optimizer
|
||||||
if component_cfg is None:
|
if component_cfg is None:
|
||||||
component_cfg = {}
|
component_cfg = {}
|
||||||
|
@ -1123,7 +1123,7 @@ class Language:
|
||||||
validate_examples(examples, "Language.rehearse")
|
validate_examples(examples, "Language.rehearse")
|
||||||
if sgd is None:
|
if sgd is None:
|
||||||
if self._optimizer is None:
|
if self._optimizer is None:
|
||||||
self._optimizer = create_default_optimizer()
|
self._optimizer = self.create_optimizer()
|
||||||
sgd = self._optimizer
|
sgd = self._optimizer
|
||||||
pipes = list(self.pipeline)
|
pipes = list(self.pipeline)
|
||||||
random.shuffle(pipes)
|
random.shuffle(pipes)
|
||||||
|
@ -1161,16 +1161,14 @@ class Language:
|
||||||
def initialize(
|
def initialize(
|
||||||
self,
|
self,
|
||||||
get_examples: Optional[Callable[[], Iterable[Example]]] = None,
|
get_examples: Optional[Callable[[], Iterable[Example]]] = None,
|
||||||
*,
|
sgd: Optional[Optimizer]=None
|
||||||
sgd: Optional[Optimizer] = None,
|
) -> None:
|
||||||
device: int = -1,
|
|
||||||
) -> Optimizer:
|
|
||||||
"""Initialize the pipe for training, using data examples if available.
|
"""Initialize the pipe for training, using data examples if available.
|
||||||
|
|
||||||
get_examples (Callable[[], Iterable[Example]]): Optional function that
|
get_examples (Callable[[], Iterable[Example]]): Optional function that
|
||||||
returns gold-standard Example objects.
|
returns gold-standard Example objects.
|
||||||
sgd (thinc.api.Optimizer): Optional optimizer. Will be created with
|
sgd (Optional[Optimizer]): An optimizer to use for updates. If not
|
||||||
create_optimizer if it doesn't exist.
|
provided, will be created using the .create_optimizer() method.
|
||||||
RETURNS (thinc.api.Optimizer): The optimizer.
|
RETURNS (thinc.api.Optimizer): The optimizer.
|
||||||
|
|
||||||
DOCS: https://nightly.spacy.io/api/language#initialize
|
DOCS: https://nightly.spacy.io/api/language#initialize
|
||||||
|
@ -1199,25 +1197,22 @@ class Language:
|
||||||
if not valid_examples:
|
if not valid_examples:
|
||||||
err = Errors.E930.format(name="Language", obj="empty list")
|
err = Errors.E930.format(name="Language", obj="empty list")
|
||||||
raise ValueError(err)
|
raise ValueError(err)
|
||||||
if device >= 0: # TODO: do we need this here?
|
if self.vocab.vectors.data.shape[1] >= 1:
|
||||||
require_gpu(device)
|
ops = get_current_ops()
|
||||||
if self.vocab.vectors.data.shape[1] >= 1:
|
self.vocab.vectors.data = ops.asarray(self.vocab.vectors.data)
|
||||||
ops = get_current_ops()
|
|
||||||
self.vocab.vectors.data = ops.asarray(self.vocab.vectors.data)
|
|
||||||
if sgd is None:
|
|
||||||
sgd = create_default_optimizer()
|
|
||||||
self._optimizer = sgd
|
|
||||||
for name, proc in self.pipeline:
|
for name, proc in self.pipeline:
|
||||||
if hasattr(proc, "initialize"):
|
if hasattr(proc, "initialize"):
|
||||||
proc.initialize(
|
proc.initialize(
|
||||||
get_examples, pipeline=self.pipeline, sgd=self._optimizer
|
get_examples, pipeline=self.pipeline
|
||||||
)
|
)
|
||||||
self._link_components()
|
self._link_components()
|
||||||
|
if sgd is not None:
|
||||||
|
self._optimizer = sgd
|
||||||
|
elif self._optimizer is None:
|
||||||
|
self._optimizer = self.create_optimizer()
|
||||||
return self._optimizer
|
return self._optimizer
|
||||||
|
|
||||||
def resume_training(
|
def resume_training(self, *, sgd: Optional[Optimizer] = None) -> Optimizer:
|
||||||
self, *, sgd: Optional[Optimizer] = None, device: int = -1
|
|
||||||
) -> Optimizer:
|
|
||||||
"""Continue training a pretrained model.
|
"""Continue training a pretrained model.
|
||||||
|
|
||||||
Create and return an optimizer, and initialize "rehearsal" for any pipeline
|
Create and return an optimizer, and initialize "rehearsal" for any pipeline
|
||||||
|
@ -1226,22 +1221,20 @@ class Language:
|
||||||
rehearsal, collect samples of text you want the models to retain performance
|
rehearsal, collect samples of text you want the models to retain performance
|
||||||
on, and call nlp.rehearse() with a batch of Example objects.
|
on, and call nlp.rehearse() with a batch of Example objects.
|
||||||
|
|
||||||
sgd (Optional[Optimizer]): An optimizer.
|
|
||||||
RETURNS (Optimizer): The optimizer.
|
RETURNS (Optimizer): The optimizer.
|
||||||
|
|
||||||
DOCS: https://nightly.spacy.io/api/language#resume_training
|
DOCS: https://nightly.spacy.io/api/language#resume_training
|
||||||
"""
|
"""
|
||||||
if device >= 0: # TODO: do we need this here?
|
ops = get_current_ops()
|
||||||
require_gpu(device)
|
if self.vocab.vectors.data.shape[1] >= 1:
|
||||||
ops = get_current_ops()
|
self.vocab.vectors.data = ops.asarray(self.vocab.vectors.data)
|
||||||
if self.vocab.vectors.data.shape[1] >= 1:
|
|
||||||
self.vocab.vectors.data = ops.asarray(self.vocab.vectors.data)
|
|
||||||
if sgd is None:
|
|
||||||
sgd = create_default_optimizer()
|
|
||||||
self._optimizer = sgd
|
|
||||||
for name, proc in self.pipeline:
|
for name, proc in self.pipeline:
|
||||||
if hasattr(proc, "_rehearsal_model"):
|
if hasattr(proc, "_rehearsal_model"):
|
||||||
proc._rehearsal_model = deepcopy(proc.model)
|
proc._rehearsal_model = deepcopy(proc.model)
|
||||||
|
if sgd is not None:
|
||||||
|
self._optimizer = sgd
|
||||||
|
elif self._optimizer is None:
|
||||||
|
self._optimizer = self.create_optimizer()
|
||||||
return self._optimizer
|
return self._optimizer
|
||||||
|
|
||||||
def evaluate(
|
def evaluate(
|
||||||
|
@ -1303,6 +1296,10 @@ class Language:
|
||||||
results["speed"] = n_words / (end_time - start_time)
|
results["speed"] = n_words / (end_time - start_time)
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
def create_optimizer(self):
|
||||||
|
"""Create an optimizer, usually using the [training.optimizer] config."""
|
||||||
|
return registry.resolve(self.config["training"]["optimizer"])
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def use_params(self, params: Optional[dict]):
|
def use_params(self, params: Optional[dict]):
|
||||||
"""Replace weights of models in the pipeline with those provided in the
|
"""Replace weights of models in the pipeline with those provided in the
|
||||||
|
|
Loading…
Reference in New Issue