Allow component decorators to re-run with same function

This commit is contained in:
Ines Montani 2020-08-28 16:27:22 +02:00
parent 89f692bc8a
commit cad988da7f
4 changed files with 72 additions and 6 deletions

View File

@ -128,7 +128,8 @@ class Errors:
"got {component} (name: '{name}'). If you're using a custom " "got {component} (name: '{name}'). If you're using a custom "
"component factory, double-check that it correctly returns your " "component factory, double-check that it correctly returns your "
"initialized component.") "initialized component.")
E004 = ("Can't set up pipeline component: a factory for '{name}' already exists.") E004 = ("Can't set up pipeline component: a factory for '{name}' already "
"exists. Existing factory: {func}. New factory: {new_func}")
E005 = ("Pipeline component '{name}' returned None. If you're using a " E005 = ("Pipeline component '{name}' returned None. If you're using a "
"custom component, maybe you forgot to return the processed Doc?") "custom component, maybe you forgot to return the processed Doc?")
E006 = ("Invalid constraints for adding pipeline component. You can only " E006 = ("Invalid constraints for adding pipeline component. You can only "

View File

@ -396,13 +396,21 @@ class Language:
style="default config", name=name, cfg_type=type(default_config) style="default config", name=name, cfg_type=type(default_config)
) )
raise ValueError(err) raise ValueError(err)
internal_name = cls.get_factory_name(name)
if internal_name in registry.factories:
# We only check for the internal name here it's okay if it's a
# subclass and the base class has a factory of the same name
raise ValueError(Errors.E004.format(name=name))
def add_factory(factory_func: Callable) -> Callable: def add_factory(factory_func: Callable) -> Callable:
internal_name = cls.get_factory_name(name)
if internal_name in registry.factories:
# We only check for the internal name here it's okay if it's a
# subclass and the base class has a factory of the same name. We
# also only raise if the function is different to prevent raising
# if module is reloaded.
existing_func = registry.factories.get(internal_name)
if not util.is_same_func(factory_func, existing_func):
err = Errors.E004.format(
name=name, func=existing_func, new_func=factory_func
)
raise ValueError(err)
arg_names = util.get_arg_names(factory_func) arg_names = util.get_arg_names(factory_func)
if "nlp" not in arg_names or "name" not in arg_names: if "nlp" not in arg_names or "name" not in arg_names:
raise ValueError(Errors.E964.format(name=name)) raise ValueError(Errors.E964.format(name=name))
@ -472,6 +480,21 @@ class Language:
def factory_func(nlp: cls, name: str) -> Callable[[Doc], Doc]: def factory_func(nlp: cls, name: str) -> Callable[[Doc], Doc]:
return component_func return component_func
internal_name = cls.get_factory_name(name)
if internal_name in registry.factories:
# We only check for the internal name here it's okay if it's a
# subclass and the base class has a factory of the same name. We
# also only raise if the function is different to prevent raising
# if module is reloaded. It's hacky, but we need to check the
# existing functure for a closure and whether that's identical
# to the component function (because factory_func created above
# will always be different, even for the same function)
existing_func = registry.factories.get(internal_name)
closure = existing_func.__closure__
wrapped = [c.cell_contents for c in closure][0] if closure else None
if util.is_same_func(wrapped, component_func):
factory_func = existing_func # noqa: F811
cls.factory( cls.factory(
component_name, component_name,
assigns=assigns, assigns=assigns,

View File

@ -438,3 +438,26 @@ def test_pipe_factories_from_source_config():
config = nlp.config["components"]["custom"] config = nlp.config["components"]["custom"]
assert config["factory"] == name assert config["factory"] == name
assert config["arg"] == "world" assert config["arg"] == "world"
def test_pipe_factories_decorator_idempotent():
"""Check that decorator can be run multiple times if the function is the
same. This is especially relevant for live reloading because we don't
want spaCy to raise an error if a module registering components is reloaded.
"""
name = "test_pipe_factories_decorator_idempotent"
func = lambda nlp, name: lambda doc: doc
for i in range(5):
Language.factory(name, func=func)
nlp = Language()
nlp.add_pipe(name)
Language.factory(name, func=func)
# Make sure it also works for component decorator, which creates the
# factory function
name2 = f"{name}2"
func2 = lambda doc: doc
for i in range(5):
Language.component(name2, func=func2)
nlp = Language()
nlp.add_pipe(name)
Language.component(name2, func=func2)

View File

@ -673,6 +673,25 @@ def get_object_name(obj: Any) -> str:
return repr(obj) return repr(obj)
def is_same_func(func1: Callable, func2: Callable) -> bool:
"""Approximately decide whether two functions are the same, even if their
identity is different (e.g. after they have been live reloaded). Mostly
used in the @Language.component and @Language.factory decorators to decide
whether to raise if a factory already exists. Allows decorator to run
multiple times with the same function.
func1 (Callable): The first function.
func2 (Callable): The second function.
RETURNS (bool): Whether it's the same function (most likely).
"""
if not callable(func1) or not callable(func2):
return False
same_name = func1.__qualname__ == func2.__qualname__
same_file = inspect.getfile(func1) == inspect.getfile(func2)
same_code = inspect.getsourcelines(func1) == inspect.getsourcelines(func2)
return all([same_name, same_file, same_code])
def get_cuda_stream( def get_cuda_stream(
require: bool = False, non_blocking: bool = True require: bool = False, non_blocking: bool = True
) -> Optional[CudaStream]: ) -> Optional[CudaStream]: