diff --git a/spacy/language.py b/spacy/language.py index 8e7c39b90..7a354ee3d 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -1,5 +1,5 @@ from typing import Optional, Any, Dict, Callable, Iterable, Union, List, Pattern -from typing import Tuple, Iterator +from typing import Tuple, Iterator, Optional from dataclasses import dataclass import random import itertools @@ -1275,7 +1275,7 @@ class Language: return results @contextmanager - def use_params(self, params: dict): + def use_params(self, params: Optional[dict]): """Replace weights of models in the pipeline with those provided in the params dictionary. Can be used as a contextmanager, in which case, models go back to their original weights after the block. @@ -1288,24 +1288,27 @@ class Language: DOCS: https://spacy.io/api/language#use_params """ - contexts = [ - pipe.use_params(params) - for name, pipe in self.pipeline - if hasattr(pipe, "use_params") and hasattr(pipe, "model") - ] - # TODO: Having trouble with contextlib - # Workaround: these aren't actually context managers atm. - for context in contexts: - try: - next(context) - except StopIteration: - pass - yield - for context in contexts: - try: - next(context) - except StopIteration: - pass + if not params: + yield + else: + contexts = [ + pipe.use_params(params) + for name, pipe in self.pipeline + if hasattr(pipe, "use_params") and hasattr(pipe, "model") + ] + # TODO: Having trouble with contextlib + # Workaround: these aren't actually context managers atm. + for context in contexts: + try: + next(context) + except StopIteration: + pass + yield + for context in contexts: + try: + next(context) + except StopIteration: + pass def pipe( self,