mirror of https://github.com/explosion/spaCy.git
Allow component decorators to re-run with same function
This commit is contained in:
parent
89f692bc8a
commit
cad988da7f
|
@ -128,7 +128,8 @@ class Errors:
|
|||
"got {component} (name: '{name}'). If you're using a custom "
|
||||
"component factory, double-check that it correctly returns your "
|
||||
"initialized component.")
|
||||
E004 = ("Can't set up pipeline component: a factory for '{name}' already exists.")
|
||||
E004 = ("Can't set up pipeline component: a factory for '{name}' already "
|
||||
"exists. Existing factory: {func}. New factory: {new_func}")
|
||||
E005 = ("Pipeline component '{name}' returned None. If you're using a "
|
||||
"custom component, maybe you forgot to return the processed Doc?")
|
||||
E006 = ("Invalid constraints for adding pipeline component. You can only "
|
||||
|
|
|
@ -396,13 +396,21 @@ class Language:
|
|||
style="default config", name=name, cfg_type=type(default_config)
|
||||
)
|
||||
raise ValueError(err)
|
||||
internal_name = cls.get_factory_name(name)
|
||||
if internal_name in registry.factories:
|
||||
# We only check for the internal name here – it's okay if it's a
|
||||
# subclass and the base class has a factory of the same name
|
||||
raise ValueError(Errors.E004.format(name=name))
|
||||
|
||||
def add_factory(factory_func: Callable) -> Callable:
|
||||
internal_name = cls.get_factory_name(name)
|
||||
if internal_name in registry.factories:
|
||||
# We only check for the internal name here – it's okay if it's a
|
||||
# subclass and the base class has a factory of the same name. We
|
||||
# also only raise if the function is different to prevent raising
|
||||
# if module is reloaded.
|
||||
existing_func = registry.factories.get(internal_name)
|
||||
if not util.is_same_func(factory_func, existing_func):
|
||||
err = Errors.E004.format(
|
||||
name=name, func=existing_func, new_func=factory_func
|
||||
)
|
||||
raise ValueError(err)
|
||||
|
||||
arg_names = util.get_arg_names(factory_func)
|
||||
if "nlp" not in arg_names or "name" not in arg_names:
|
||||
raise ValueError(Errors.E964.format(name=name))
|
||||
|
@ -472,6 +480,21 @@ class Language:
|
|||
def factory_func(nlp: cls, name: str) -> Callable[[Doc], Doc]:
|
||||
return component_func
|
||||
|
||||
internal_name = cls.get_factory_name(name)
|
||||
if internal_name in registry.factories:
|
||||
# We only check for the internal name here – it's okay if it's a
|
||||
# subclass and the base class has a factory of the same name. We
|
||||
# also only raise if the function is different to prevent raising
|
||||
# if module is reloaded. It's hacky, but we need to check the
|
||||
# existing function for a closure and whether that's identical
|
||||
# to the component function (because factory_func created above
|
||||
# will always be different, even for the same function)
|
||||
existing_func = registry.factories.get(internal_name)
|
||||
closure = existing_func.__closure__
|
||||
wrapped = [c.cell_contents for c in closure][0] if closure else None
|
||||
if util.is_same_func(wrapped, component_func):
|
||||
factory_func = existing_func # noqa: F811
|
||||
|
||||
cls.factory(
|
||||
component_name,
|
||||
assigns=assigns,
|
||||
|
|
|
@ -438,3 +438,26 @@ def test_pipe_factories_from_source_config():
|
|||
config = nlp.config["components"]["custom"]
|
||||
assert config["factory"] == name
|
||||
assert config["arg"] == "world"
|
||||
|
||||
|
||||
def test_pipe_factories_decorator_idempotent():
    """Check that decorator can be run multiple times if the function is the
    same. This is especially relevant for live reloading because we don't
    want spaCy to raise an error if a module registering components is reloaded.
    """
    name = "test_pipe_factories_decorator_idempotent"
    func = lambda nlp, name: lambda doc: doc
    # Registering the very same factory function repeatedly must not raise.
    for _ in range(5):
        Language.factory(name, func=func)
    nlp = Language()
    nlp.add_pipe(name)
    # Registering again after the factory has been used must not raise either.
    Language.factory(name, func=func)
    # Make sure it also works for component decorator, which creates the
    # factory function
    name2 = f"{name}2"
    func2 = lambda doc: doc
    for _ in range(5):
        Language.component(name2, func=func2)
    nlp = Language()
    # Add the component registered under name2 (not the factory registered
    # under name above), so the component code path is actually exercised.
    nlp.add_pipe(name2)
    Language.component(name2, func=func2)
|
||||
|
|
|
@ -673,6 +673,25 @@ def get_object_name(obj: Any) -> str:
|
|||
return repr(obj)
|
||||
|
||||
|
||||
def is_same_func(func1: Callable, func2: Callable) -> bool:
    """Approximately decide whether two functions are the same, even if their
    identity is different (e.g. after they have been live reloaded). Mostly
    used in the @Language.component and @Language.factory decorators to decide
    whether to raise if a factory already exists. Allows decorator to run
    multiple times with the same function.

    Two distinct objects are considered the same function if they have the
    same qualified name, are defined in the same file and have identical
    source lines (including the starting line number). If any of this can't
    be determined (e.g. callables without a __qualname__, built-ins, or
    functions whose source isn't available), they're treated as different.

    func1 (Callable): The first function.
    func2 (Callable): The second function.
    RETURNS (bool): Whether it's the same function (most likely).
    """
    if not callable(func1) or not callable(func2):
        return False
    # The very same object is trivially the same function. This also covers
    # callables that can't be inspected at all (built-ins, callable instances).
    if func1 is func2:
        return True
    # Not every callable has a __qualname__ (e.g. instances defining __call__),
    # so don't assume the attribute exists.
    qualname1 = getattr(func1, "__qualname__", None)
    qualname2 = getattr(func2, "__qualname__", None)
    if qualname1 is None or qualname1 != qualname2:
        return False
    try:
        if inspect.getfile(func1) != inspect.getfile(func2):
            return False
        # Most expensive check (reads the source file), so it comes last.
        return inspect.getsourcelines(func1) == inspect.getsourcelines(func2)
    except (TypeError, OSError):
        # inspect.getfile raises TypeError for built-in objects and
        # inspect.getsourcelines raises OSError if the source code can't be
        # retrieved. We can't prove the functions are the same in that case.
        return False
|
||||
|
||||
|
||||
def get_cuda_stream(
|
||||
require: bool = False, non_blocking: bool = True
|
||||
) -> Optional[CudaStream]:
|
||||
|
|
Loading…
Reference in New Issue