mirror of https://github.com/explosion/spaCy.git
Allow component decorators to re-run with same function
This commit is contained in:
parent
89f692bc8a
commit
cad988da7f
|
@ -128,7 +128,8 @@ class Errors:
|
||||||
"got {component} (name: '{name}'). If you're using a custom "
|
"got {component} (name: '{name}'). If you're using a custom "
|
||||||
"component factory, double-check that it correctly returns your "
|
"component factory, double-check that it correctly returns your "
|
||||||
"initialized component.")
|
"initialized component.")
|
||||||
E004 = ("Can't set up pipeline component: a factory for '{name}' already exists.")
|
E004 = ("Can't set up pipeline component: a factory for '{name}' already "
|
||||||
|
"exists. Existing factory: {func}. New factory: {new_func}")
|
||||||
E005 = ("Pipeline component '{name}' returned None. If you're using a "
|
E005 = ("Pipeline component '{name}' returned None. If you're using a "
|
||||||
"custom component, maybe you forgot to return the processed Doc?")
|
"custom component, maybe you forgot to return the processed Doc?")
|
||||||
E006 = ("Invalid constraints for adding pipeline component. You can only "
|
E006 = ("Invalid constraints for adding pipeline component. You can only "
|
||||||
|
|
|
@ -396,13 +396,21 @@ class Language:
|
||||||
style="default config", name=name, cfg_type=type(default_config)
|
style="default config", name=name, cfg_type=type(default_config)
|
||||||
)
|
)
|
||||||
raise ValueError(err)
|
raise ValueError(err)
|
||||||
internal_name = cls.get_factory_name(name)
|
|
||||||
if internal_name in registry.factories:
|
|
||||||
# We only check for the internal name here – it's okay if it's a
|
|
||||||
# subclass and the base class has a factory of the same name
|
|
||||||
raise ValueError(Errors.E004.format(name=name))
|
|
||||||
|
|
||||||
def add_factory(factory_func: Callable) -> Callable:
|
def add_factory(factory_func: Callable) -> Callable:
|
||||||
|
internal_name = cls.get_factory_name(name)
|
||||||
|
if internal_name in registry.factories:
|
||||||
|
# We only check for the internal name here – it's okay if it's a
|
||||||
|
# subclass and the base class has a factory of the same name. We
|
||||||
|
# also only raise if the function is different to prevent raising
|
||||||
|
# if module is reloaded.
|
||||||
|
existing_func = registry.factories.get(internal_name)
|
||||||
|
if not util.is_same_func(factory_func, existing_func):
|
||||||
|
err = Errors.E004.format(
|
||||||
|
name=name, func=existing_func, new_func=factory_func
|
||||||
|
)
|
||||||
|
raise ValueError(err)
|
||||||
|
|
||||||
arg_names = util.get_arg_names(factory_func)
|
arg_names = util.get_arg_names(factory_func)
|
||||||
if "nlp" not in arg_names or "name" not in arg_names:
|
if "nlp" not in arg_names or "name" not in arg_names:
|
||||||
raise ValueError(Errors.E964.format(name=name))
|
raise ValueError(Errors.E964.format(name=name))
|
||||||
|
@ -472,6 +480,21 @@ class Language:
|
||||||
def factory_func(nlp: cls, name: str) -> Callable[[Doc], Doc]:
|
def factory_func(nlp: cls, name: str) -> Callable[[Doc], Doc]:
|
||||||
return component_func
|
return component_func
|
||||||
|
|
||||||
|
internal_name = cls.get_factory_name(name)
|
||||||
|
if internal_name in registry.factories:
|
||||||
|
# We only check for the internal name here – it's okay if it's a
|
||||||
|
# subclass and the base class has a factory of the same name. We
|
||||||
|
# also only raise if the function is different to prevent raising
|
||||||
|
# if module is reloaded. It's hacky, but we need to check the
|
||||||
|
# existing function for a closure and whether that's identical
|
||||||
|
# to the component function (because factory_func created above
|
||||||
|
# will always be different, even for the same function)
|
||||||
|
existing_func = registry.factories.get(internal_name)
|
||||||
|
closure = existing_func.__closure__
|
||||||
|
wrapped = [c.cell_contents for c in closure][0] if closure else None
|
||||||
|
if util.is_same_func(wrapped, component_func):
|
||||||
|
factory_func = existing_func # noqa: F811
|
||||||
|
|
||||||
cls.factory(
|
cls.factory(
|
||||||
component_name,
|
component_name,
|
||||||
assigns=assigns,
|
assigns=assigns,
|
||||||
|
|
|
@ -438,3 +438,26 @@ def test_pipe_factories_from_source_config():
|
||||||
config = nlp.config["components"]["custom"]
|
config = nlp.config["components"]["custom"]
|
||||||
assert config["factory"] == name
|
assert config["factory"] == name
|
||||||
assert config["arg"] == "world"
|
assert config["arg"] == "world"
|
||||||
|
|
||||||
|
|
||||||
|
def test_pipe_factories_decorator_idempotent():
    """Check that decorator can be run multiple times if the function is the
    same. This is especially relevant for live reloading because we don't
    want spaCy to raise an error if a module registering components is reloaded.
    """
    name = "test_pipe_factories_decorator_idempotent"
    func = lambda nlp, name: lambda doc: doc
    # Registering the identical factory function repeatedly must not raise.
    for _ in range(5):
        Language.factory(name, func=func)
    nlp = Language()
    nlp.add_pipe(name)
    # Registering again after the factory has been used must not raise either.
    Language.factory(name, func=func)
    # Make sure it also works for component decorator, which creates the
    # factory function
    name2 = f"{name}2"
    func2 = lambda doc: doc
    for _ in range(5):
        Language.component(name2, func=func2)
    nlp = Language()
    # Add the component registered under name2 (not the factory from the
    # first half of the test), so the component decorator path is the one
    # actually exercised here.
    nlp.add_pipe(name2)
    Language.component(name2, func=func2)
|
||||||
|
|
|
@ -673,6 +673,25 @@ def get_object_name(obj: Any) -> str:
|
||||||
return repr(obj)
|
return repr(obj)
|
||||||
|
|
||||||
|
|
||||||
|
def is_same_func(func1: Callable, func2: Callable) -> bool:
    """Approximately decide whether two functions are the same, even if their
    identity is different (e.g. after they have been live reloaded). Mostly
    used in the @Language.component and @Language.factory decorators to decide
    whether to raise if a factory already exists. Allows decorator to run
    multiple times with the same function.

    Two distinct callables are considered the same if their qualified name,
    their source file and their source lines all match. Distinct callables
    that can't be introspected (no __qualname__, or no retrievable source,
    e.g. builtins or callable instances) are never considered the same.

    func1 (Callable): The first function.
    func2 (Callable): The second function.
    RETURNS (bool): Whether it's the same function (most likely).
    """
    if not callable(func1) or not callable(func2):
        return False
    # The very same object is trivially the same function. Checking this
    # first also avoids introspecting callables that have no source.
    if func1 is func2:
        return True
    # Callable instances and functools.partial objects have no __qualname__.
    if not hasattr(func1, "__qualname__") or not hasattr(func2, "__qualname__"):
        return False
    if func1.__qualname__ != func2.__qualname__:
        return False
    try:
        if inspect.getfile(func1) != inspect.getfile(func2):
            return False
        return inspect.getsourcelines(func1) == inspect.getsourcelines(func2)
    except (TypeError, OSError):
        # inspect raises TypeError for builtins and OSError if the source
        # can't be retrieved. Without source we can't show they're the same.
        return False
|
||||||
|
|
||||||
|
|
||||||
def get_cuda_stream(
|
def get_cuda_stream(
|
||||||
require: bool = False, non_blocking: bool = True
|
require: bool = False, non_blocking: bool = True
|
||||||
) -> Optional[CudaStream]:
|
) -> Optional[CudaStream]:
|
||||||
|
|
Loading…
Reference in New Issue