Let Langugae.use_params work with falsey inputs

The Language.use_params method was failing if you passed in None, which
meant we had to use awkward conditionals for the parameter averaging.
This solves the problem.
This commit is contained in:
Matthew Honnibal 2020-09-03 12:51:04 +02:00
parent 122cb02001
commit ef0d0630a4
1 changed files with 23 additions and 20 deletions

View File

@ -1,5 +1,5 @@
from typing import Optional, Any, Dict, Callable, Iterable, Union, List, Pattern from typing import Optional, Any, Dict, Callable, Iterable, Union, List, Pattern
from typing import Tuple, Iterator from typing import Tuple, Iterator, Optional
from dataclasses import dataclass from dataclasses import dataclass
import random import random
import itertools import itertools
@ -1275,7 +1275,7 @@ class Language:
return results return results
@contextmanager @contextmanager
def use_params(self, params: dict): def use_params(self, params: Optional[dict]):
"""Replace weights of models in the pipeline with those provided in the """Replace weights of models in the pipeline with those provided in the
params dictionary. Can be used as a contextmanager, in which case, params dictionary. Can be used as a contextmanager, in which case,
models go back to their original weights after the block. models go back to their original weights after the block.
@ -1288,6 +1288,9 @@ class Language:
DOCS: https://spacy.io/api/language#use_params DOCS: https://spacy.io/api/language#use_params
""" """
if not params:
yield
else:
contexts = [ contexts = [
pipe.use_params(params) pipe.use_params(params)
for name, pipe in self.pipeline for name, pipe in self.pipeline