Let Langugae.use_params work with falsey inputs

The Language.use_params method was failing if you passed in None, which
meant we had to use awkward conditionals for the parameter averaging.
This solves the problem.
This commit is contained in:
Matthew Honnibal 2020-09-03 12:51:04 +02:00
parent 122cb02001
commit ef0d0630a4
1 changed files with 23 additions and 20 deletions

View File

@ -1,5 +1,5 @@
from typing import Optional, Any, Dict, Callable, Iterable, Union, List, Pattern
from typing import Tuple, Iterator
from typing import Tuple, Iterator, Optional
from dataclasses import dataclass
import random
import itertools
@ -1275,7 +1275,7 @@ class Language:
return results
@contextmanager
def use_params(self, params: dict):
def use_params(self, params: Optional[dict]):
"""Replace weights of models in the pipeline with those provided in the
params dictionary. Can be used as a contextmanager, in which case,
models go back to their original weights after the block.
@ -1288,24 +1288,27 @@ class Language:
DOCS: https://spacy.io/api/language#use_params
"""
contexts = [
pipe.use_params(params)
for name, pipe in self.pipeline
if hasattr(pipe, "use_params") and hasattr(pipe, "model")
]
# TODO: Having trouble with contextlib
# Workaround: these aren't actually context managers atm.
for context in contexts:
try:
next(context)
except StopIteration:
pass
yield
for context in contexts:
try:
next(context)
except StopIteration:
pass
if not params:
yield
else:
contexts = [
pipe.use_params(params)
for name, pipe in self.pipeline
if hasattr(pipe, "use_params") and hasattr(pipe, "model")
]
# TODO: Having trouble with contextlib
# Workaround: these aren't actually context managers atm.
for context in contexts:
try:
next(context)
except StopIteration:
pass
yield
for context in contexts:
try:
next(context)
except StopIteration:
pass
def pipe(
self,