mirror of https://github.com/explosion/spaCy.git
Allow loaded but disabled components
This commit is contained in:
parent
adc050cdc5
commit
3ce5be4b76
|
@ -28,17 +28,22 @@ if sys.maxunicode == 65535:
|
|||
def load(
|
||||
name: Union[str, Path],
|
||||
disable: Iterable[str] = tuple(),
|
||||
exclude: Iterable[str] = tuple(),
|
||||
config: Union[Dict[str, Any], Config] = util.SimpleFrozenDict(),
|
||||
) -> Language:
|
||||
"""Load a spaCy model from an installed package or a local path.
|
||||
|
||||
name (str): Package name or model path.
|
||||
disable (Iterable[str]): Names of pipeline components to disable.
|
||||
disable (Iterable[str]): Names of pipeline components to disable. Disabled
|
||||
pipes will be loaded but they won't be run unless you explicitly
|
||||
enable them by calling nlp.enable_pipe.
|
||||
exclude (Iterable[str]): Names of pipeline components to exclude. Excluded
|
||||
components won't be loaded.
|
||||
config (Dict[str, Any] / Config): Config overrides as nested dict or dict
|
||||
keyed by section values in dot notation.
|
||||
RETURNS (Language): The loaded nlp object.
|
||||
"""
|
||||
return util.load_model(name, disable=disable, config=config)
|
||||
return util.load_model(name, disable=disable, exclude=exclude, config=config)
|
||||
|
||||
|
||||
def blank(name: str, **overrides) -> Language:
|
||||
|
|
|
@ -11,6 +11,7 @@ use_pytorch_for_gpu_memory = false
|
|||
[nlp]
|
||||
lang = null
|
||||
pipeline = []
|
||||
disabled = []
|
||||
load_vocab_data = true
|
||||
before_creation = null
|
||||
after_creation = null
|
||||
|
|
|
@ -6,7 +6,7 @@ import itertools
|
|||
import weakref
|
||||
import functools
|
||||
from contextlib import contextmanager
|
||||
from copy import copy, deepcopy
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
import warnings
|
||||
from thinc.api import get_current_ops, Config, require_gpu, Optimizer
|
||||
|
@ -159,7 +159,8 @@ class Language:
|
|||
self.vocab: Vocab = vocab
|
||||
if self.lang is None:
|
||||
self.lang = self.vocab.lang
|
||||
self.pipeline = []
|
||||
self._pipeline = []
|
||||
self._disabled = set()
|
||||
self.max_length = max_length
|
||||
self.resolved = {}
|
||||
# Create the default tokenizer from the default config
|
||||
|
@ -210,6 +211,7 @@ class Language:
|
|||
# TODO: Adding this back to prevent breaking people's code etc., but
|
||||
# we should consider removing it
|
||||
self._meta["pipeline"] = self.pipe_names
|
||||
self._meta["disabled"] = list(self._disabled)
|
||||
return self._meta
|
||||
|
||||
@meta.setter
|
||||
|
@ -232,13 +234,14 @@ class Language:
|
|||
# we can populate the config again later
|
||||
pipeline = {}
|
||||
score_weights = []
|
||||
for pipe_name in self.pipe_names:
|
||||
for pipe_name in self._pipe_names:
|
||||
pipe_meta = self.get_pipe_meta(pipe_name)
|
||||
pipe_config = self.get_pipe_config(pipe_name)
|
||||
pipeline[pipe_name] = {"factory": pipe_meta.factory, **pipe_config}
|
||||
if pipe_meta.default_score_weights:
|
||||
score_weights.append(pipe_meta.default_score_weights)
|
||||
self._config["nlp"]["pipeline"] = self.pipe_names
|
||||
self._config["nlp"]["pipeline"] = self._pipe_names
|
||||
self._config["nlp"]["disabled"] = list(self._disabled)
|
||||
self._config["components"] = pipeline
|
||||
self._config["training"]["score_weights"] = combine_score_weights(score_weights)
|
||||
if not srsly.is_json_serializable(self._config):
|
||||
|
@ -257,9 +260,30 @@ class Language:
|
|||
"""
|
||||
return list(self.factories.keys())
|
||||
|
||||
@property
|
||||
def _pipe_names(self) -> List[str]:
|
||||
"""Get the names of the available pipeline components. Includes all
|
||||
active and inactive pipeline components.
|
||||
|
||||
RETURNS (List[str]): List of component name strings, in order.
|
||||
"""
|
||||
# TODO: Should we make this available via a user-facing property? (The
|
||||
# underscore distinction works well internally)
|
||||
return [pipe_name for pipe_name, _ in self._pipeline]
|
||||
|
||||
@property
|
||||
def pipeline(self) -> List[Tuple[str, Callable[[Doc], Doc]]]:
|
||||
"""The processing pipeline consisting of (name, component) tuples. The
|
||||
components are called on the Doc in order as it passes through the
|
||||
pipeline.
|
||||
|
||||
RETURNS (List[Tuple[str, Callable[[Doc], Doc]]]): The pipeline.
|
||||
"""
|
||||
return [(name, p) for name, p in self._pipeline if name not in self._disabled]
|
||||
|
||||
@property
|
||||
def pipe_names(self) -> List[str]:
|
||||
"""Get names of available pipeline components.
|
||||
"""Get names of available active pipeline components.
|
||||
|
||||
RETURNS (List[str]): List of component name strings, in order.
|
||||
"""
|
||||
|
@ -272,7 +296,7 @@ class Language:
|
|||
RETURNS (Dict[str, str]): Factory names, keyed by component names.
|
||||
"""
|
||||
factories = {}
|
||||
for pipe_name, pipe in self.pipeline:
|
||||
for pipe_name, pipe in self._pipeline:
|
||||
factories[pipe_name] = self.get_pipe_meta(pipe_name).factory
|
||||
return factories
|
||||
|
||||
|
@ -284,7 +308,7 @@ class Language:
|
|||
RETURNS (Dict[str, List[str]]): Labels keyed by component name.
|
||||
"""
|
||||
labels = {}
|
||||
for name, pipe in self.pipeline:
|
||||
for name, pipe in self._pipeline:
|
||||
if hasattr(pipe, "labels"):
|
||||
labels[name] = list(pipe.labels)
|
||||
return labels
|
||||
|
@ -512,10 +536,10 @@ class Language:
|
|||
|
||||
DOCS: https://spacy.io/api/language#get_pipe
|
||||
"""
|
||||
for pipe_name, component in self.pipeline:
|
||||
for pipe_name, component in self._pipeline:
|
||||
if pipe_name == name:
|
||||
return component
|
||||
raise KeyError(Errors.E001.format(name=name, opts=self.pipe_names))
|
||||
raise KeyError(Errors.E001.format(name=name, opts=self._pipe_names))
|
||||
|
||||
def create_pipe(
|
||||
self,
|
||||
|
@ -660,8 +684,8 @@ class Language:
|
|||
err = Errors.E966.format(component=bad_val, name=name)
|
||||
raise ValueError(err)
|
||||
name = name if name is not None else factory_name
|
||||
if name in self.pipe_names:
|
||||
raise ValueError(Errors.E007.format(name=name, opts=self.pipe_names))
|
||||
if name in self._pipe_names:
|
||||
raise ValueError(Errors.E007.format(name=name, opts=self._pipe_names))
|
||||
if source is not None:
|
||||
# We're loading the component from a model. After loading the
|
||||
# component, we know its real factory name
|
||||
|
@ -686,7 +710,7 @@ class Language:
|
|||
)
|
||||
pipe_index = self._get_pipe_index(before, after, first, last)
|
||||
self._pipe_meta[name] = self.get_factory_meta(factory_name)
|
||||
self.pipeline.insert(pipe_index, (name, pipe_component))
|
||||
self._pipeline.insert(pipe_index, (name, pipe_component))
|
||||
return pipe_component
|
||||
|
||||
def _get_pipe_index(
|
||||
|
@ -707,32 +731,34 @@ class Language:
|
|||
"""
|
||||
all_args = {"before": before, "after": after, "first": first, "last": last}
|
||||
if sum(arg is not None for arg in [before, after, first, last]) >= 2:
|
||||
raise ValueError(Errors.E006.format(args=all_args, opts=self.pipe_names))
|
||||
raise ValueError(Errors.E006.format(args=all_args, opts=self._pipe_names))
|
||||
if last or not any(value is not None for value in [first, before, after]):
|
||||
return len(self.pipeline)
|
||||
return len(self._pipeline)
|
||||
elif first:
|
||||
return 0
|
||||
elif isinstance(before, str):
|
||||
if before not in self.pipe_names:
|
||||
raise ValueError(Errors.E001.format(name=before, opts=self.pipe_names))
|
||||
return self.pipe_names.index(before)
|
||||
if before not in self._pipe_names:
|
||||
raise ValueError(Errors.E001.format(name=before, opts=self._pipe_names))
|
||||
return self._pipe_names.index(before)
|
||||
elif isinstance(after, str):
|
||||
if after not in self.pipe_names:
|
||||
raise ValueError(Errors.E001.format(name=after, opts=self.pipe_names))
|
||||
return self.pipe_names.index(after) + 1
|
||||
if after not in self._pipe_names:
|
||||
raise ValueError(Errors.E001.format(name=after, opts=self._pipe_names))
|
||||
return self._pipe_names.index(after) + 1
|
||||
# We're only accepting indices referring to components that exist
|
||||
# (can't just do isinstance here because bools are instance of int, too)
|
||||
elif type(before) == int:
|
||||
if before >= len(self.pipeline) or before < 0:
|
||||
err = Errors.E959.format(dir="before", idx=before, opts=self.pipe_names)
|
||||
if before >= len(self._pipeline) or before < 0:
|
||||
err = Errors.E959.format(
|
||||
dir="before", idx=before, opts=self._pipe_names
|
||||
)
|
||||
raise ValueError(err)
|
||||
return before
|
||||
elif type(after) == int:
|
||||
if after >= len(self.pipeline) or after < 0:
|
||||
err = Errors.E959.format(dir="after", idx=after, opts=self.pipe_names)
|
||||
if after >= len(self._pipeline) or after < 0:
|
||||
err = Errors.E959.format(dir="after", idx=after, opts=self._pipe_names)
|
||||
raise ValueError(err)
|
||||
return after + 1
|
||||
raise ValueError(Errors.E006.format(args=all_args, opts=self.pipe_names))
|
||||
raise ValueError(Errors.E006.format(args=all_args, opts=self._pipe_names))
|
||||
|
||||
def has_pipe(self, name: str) -> bool:
|
||||
"""Check if a component name is present in the pipeline. Equivalent to
|
||||
|
@ -773,7 +799,7 @@ class Language:
|
|||
# to Language.pipeline to make sure the configs are handled correctly
|
||||
pipe_index = self.pipe_names.index(name)
|
||||
self.remove_pipe(name)
|
||||
if not len(self.pipeline) or pipe_index == len(self.pipeline):
|
||||
if not len(self._pipeline) or pipe_index == len(self._pipeline):
|
||||
# we have no components to insert before/after, or we're replacing the last component
|
||||
self.add_pipe(factory_name, name=name, config=config, validate=validate)
|
||||
else:
|
||||
|
@ -793,12 +819,12 @@ class Language:
|
|||
|
||||
DOCS: https://spacy.io/api/language#rename_pipe
|
||||
"""
|
||||
if old_name not in self.pipe_names:
|
||||
raise ValueError(Errors.E001.format(name=old_name, opts=self.pipe_names))
|
||||
if new_name in self.pipe_names:
|
||||
raise ValueError(Errors.E007.format(name=new_name, opts=self.pipe_names))
|
||||
i = self.pipe_names.index(old_name)
|
||||
self.pipeline[i] = (new_name, self.pipeline[i][1])
|
||||
if old_name not in self._pipe_names:
|
||||
raise ValueError(Errors.E001.format(name=old_name, opts=self._pipe_names))
|
||||
if new_name in self._pipe_names:
|
||||
raise ValueError(Errors.E007.format(name=new_name, opts=self._pipe_names))
|
||||
i = self._pipe_names.index(old_name)
|
||||
self._pipeline[i] = (new_name, self._pipeline[i][1])
|
||||
self._pipe_meta[new_name] = self._pipe_meta.pop(old_name)
|
||||
self._pipe_configs[new_name] = self._pipe_configs.pop(old_name)
|
||||
|
||||
|
@ -810,15 +836,41 @@ class Language:
|
|||
|
||||
DOCS: https://spacy.io/api/language#remove_pipe
|
||||
"""
|
||||
if name not in self.pipe_names:
|
||||
raise ValueError(Errors.E001.format(name=name, opts=self.pipe_names))
|
||||
removed = self.pipeline.pop(self.pipe_names.index(name))
|
||||
if name not in self._pipe_names:
|
||||
raise ValueError(Errors.E001.format(name=name, opts=self._pipe_names))
|
||||
removed = self._pipeline.pop(self._pipe_names.index(name))
|
||||
# We're only removing the component itself from the metas/configs here
|
||||
# because factory may be used for something else
|
||||
self._pipe_meta.pop(name)
|
||||
self._pipe_configs.pop(name)
|
||||
# Make sure the name is also removed from the set of disabled components
|
||||
if name in self._disabled:
|
||||
self._disabled.remove(name)
|
||||
return removed
|
||||
|
||||
def disable_pipe(self, name: str) -> None:
|
||||
"""Disable a pipeline component. The component will still exist on
|
||||
the nlp object, but it won't be run as part of the pipeline.
|
||||
|
||||
name (str): The name of the component to disable.
|
||||
"""
|
||||
if name not in self._pipe_names:
|
||||
raise ValueError(Errors.E001.format(name=name, opts=self._pipe_names))
|
||||
# TODO: should we raise if pipe is already disabled?
|
||||
self._disabled.add(name)
|
||||
|
||||
def enable_pipe(self, name: str) -> None:
|
||||
"""Enable a previously disabled pipeline component so it's run as part
|
||||
of the pipeline.
|
||||
|
||||
name (str): The name of the component to enable.
|
||||
"""
|
||||
if name not in self._pipe_names:
|
||||
raise ValueError(Errors.E001.format(name=name, opts=self._pipe_names))
|
||||
# TODO: should we raise if pipe is already enabled?
|
||||
if name in self._disabled:
|
||||
self._disabled.remove(name)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
text: str,
|
||||
|
@ -1366,6 +1418,7 @@ class Language:
|
|||
*,
|
||||
vocab: Union[Vocab, bool] = True,
|
||||
disable: Iterable[str] = tuple(),
|
||||
exclude: Iterable[str] = tuple(),
|
||||
auto_fill: bool = True,
|
||||
validate: bool = True,
|
||||
) -> "Language":
|
||||
|
@ -1375,7 +1428,11 @@ class Language:
|
|||
|
||||
config (Dict[str, Any] / Config): The loaded config.
|
||||
vocab (Vocab): A Vocab object. If True, a vocab is created.
|
||||
disable (Iterable[str]): List of pipeline component names to disable.
|
||||
disable (Iterable[str]): Names of pipeline components to disable.
|
||||
Disabled pipes will be loaded but they won't be run unless you
|
||||
explicitly enable them by calling nlp.enable_pipe.
|
||||
exclude (Iterable[str]): Names of pipeline components to exclude.
|
||||
Excluded components won't be loaded.
|
||||
auto_fill (bool): Automatically fill in missing values in config based
|
||||
on defaults and function argument annotations.
|
||||
validate (bool): Validate the component config and arguments against
|
||||
|
@ -1448,7 +1505,7 @@ class Language:
|
|||
raise ValueError(Errors.E956.format(name=pipe_name, opts=opts))
|
||||
pipe_cfg = util.copy_config(pipeline[pipe_name])
|
||||
raw_config = Config(filled["components"][pipe_name])
|
||||
if pipe_name not in disable:
|
||||
if pipe_name not in exclude:
|
||||
if "factory" not in pipe_cfg and "source" not in pipe_cfg:
|
||||
err = Errors.E984.format(name=pipe_name, config=pipe_cfg)
|
||||
raise ValueError(err)
|
||||
|
@ -1473,6 +1530,8 @@ class Language:
|
|||
)
|
||||
source_name = pipe_cfg.get("component", pipe_name)
|
||||
nlp.add_pipe(source_name, source=source_nlps[model], name=pipe_name)
|
||||
disabled_pipes = [*config["nlp"]["disabled"], *disable]
|
||||
nlp._disabled = set(p for p in disabled_pipes if p not in exclude)
|
||||
nlp.config = filled if auto_fill else config
|
||||
nlp.resolved = resolved
|
||||
if after_pipeline_creation is not None:
|
||||
|
@ -1502,9 +1561,10 @@ class Language:
|
|||
)
|
||||
serializers["meta.json"] = lambda p: srsly.write_json(p, self.meta)
|
||||
serializers["config.cfg"] = lambda p: self.config.to_disk(p)
|
||||
for name, proc in self.pipeline:
|
||||
if not hasattr(proc, "name"):
|
||||
continue
|
||||
for name, proc in self._pipeline:
|
||||
# TODO: why did we add this?
|
||||
# if not hasattr(proc, "name"):
|
||||
# continue
|
||||
if name in exclude:
|
||||
continue
|
||||
if not hasattr(proc, "to_disk"):
|
||||
|
@ -1548,7 +1608,7 @@ class Language:
|
|||
deserializers["tokenizer"] = lambda p: self.tokenizer.from_disk(
|
||||
p, exclude=["vocab"]
|
||||
)
|
||||
for name, proc in self.pipeline:
|
||||
for name, proc in self._pipeline:
|
||||
if name in exclude:
|
||||
continue
|
||||
if not hasattr(proc, "from_disk"):
|
||||
|
@ -1577,7 +1637,7 @@ class Language:
|
|||
serializers["tokenizer"] = lambda: self.tokenizer.to_bytes(exclude=["vocab"])
|
||||
serializers["meta.json"] = lambda: srsly.json_dumps(self.meta)
|
||||
serializers["config.cfg"] = lambda: self.config.to_bytes()
|
||||
for name, proc in self.pipeline:
|
||||
for name, proc in self._pipeline:
|
||||
if name in exclude:
|
||||
continue
|
||||
if not hasattr(proc, "to_bytes"):
|
||||
|
@ -1611,7 +1671,7 @@ class Language:
|
|||
deserializers["tokenizer"] = lambda b: self.tokenizer.from_bytes(
|
||||
b, exclude=["vocab"]
|
||||
)
|
||||
for name, proc in self.pipeline:
|
||||
for name, proc in self._pipeline:
|
||||
if name in exclude:
|
||||
continue
|
||||
if not hasattr(proc, "from_bytes"):
|
||||
|
@ -1647,14 +1707,10 @@ class DisabledPipes(list):
|
|||
def __init__(self, nlp: Language, names: List[str]) -> None:
|
||||
self.nlp = nlp
|
||||
self.names = names
|
||||
# Important! Not deep copy -- we just want the container (but we also
|
||||
# want to support people providing arbitrarily typed nlp.pipeline
|
||||
# objects.)
|
||||
self.original_pipeline = copy(nlp.pipeline)
|
||||
self.metas = {name: nlp.get_pipe_meta(name) for name in names}
|
||||
self.configs = {name: nlp.get_pipe_config(name) for name in names}
|
||||
for name in self.names:
|
||||
self.nlp.disable_pipe(name)
|
||||
list.__init__(self)
|
||||
self.extend(nlp.remove_pipe(name) for name in names)
|
||||
self.extend(self.names)
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
@ -1664,14 +1720,10 @@ class DisabledPipes(list):
|
|||
|
||||
def restore(self) -> None:
|
||||
"""Restore the pipeline to its state when DisabledPipes was created."""
|
||||
current, self.nlp.pipeline = self.nlp.pipeline, self.original_pipeline
|
||||
unexpected = [name for name, pipe in current if not self.nlp.has_pipe(name)]
|
||||
if unexpected:
|
||||
# Don't change the pipeline if we're raising an error.
|
||||
self.nlp.pipeline = current
|
||||
raise ValueError(Errors.E008.format(names=unexpected))
|
||||
self.nlp._pipe_meta.update(self.metas)
|
||||
self.nlp._pipe_configs.update(self.configs)
|
||||
for name in self.names:
|
||||
self.nlp.enable_pipe(name)
|
||||
# TODO: maybe add some more checks / catch errors that may occur if
|
||||
# user removes a disabled pipe in the with block
|
||||
self[:] = []
|
||||
|
||||
|
||||
|
|
|
@ -223,6 +223,7 @@ class ConfigSchemaNlp(BaseModel):
|
|||
# fmt: off
|
||||
lang: StrictStr = Field(..., title="The base language to use")
|
||||
pipeline: List[StrictStr] = Field(..., title="The pipeline component names in order")
|
||||
disabled: List[StrictStr] = Field(..., title="Pipeline components to disable by default")
|
||||
tokenizer: Callable = Field(..., title="The tokenizer to use")
|
||||
load_vocab_data: StrictBool = Field(..., title="Whether to load additional vocab data from spacy-lookups-data")
|
||||
before_creation: Optional[Callable[[Type["Language"]], Type["Language"]]] = Field(..., title="Optional callback to modify Language class before initialization")
|
||||
|
|
|
@ -249,3 +249,66 @@ def test_add_pipe_before_after():
|
|||
nlp.add_pipe("entity_ruler", before=True)
|
||||
with pytest.raises(ValueError):
|
||||
nlp.add_pipe("entity_ruler", first=False)
|
||||
|
||||
|
||||
def test_disable_enable_pipes():
|
||||
name = "test_disable_enable_pipes"
|
||||
results = {}
|
||||
|
||||
def make_component(name):
|
||||
results[name] = ""
|
||||
|
||||
def component(doc):
|
||||
nonlocal results
|
||||
results[name] = doc.text
|
||||
return doc
|
||||
|
||||
return component
|
||||
|
||||
c1 = Language.component(f"{name}1", func=make_component(f"{name}1"))
|
||||
c2 = Language.component(f"{name}2", func=make_component(f"{name}2"))
|
||||
|
||||
nlp = Language()
|
||||
nlp.add_pipe(f"{name}1")
|
||||
nlp.add_pipe(f"{name}2")
|
||||
assert results[f"{name}1"] == ""
|
||||
assert results[f"{name}2"] == ""
|
||||
assert nlp.pipeline == [(f"{name}1", c1), (f"{name}2", c2)]
|
||||
assert nlp.pipe_names == [f"{name}1", f"{name}2"]
|
||||
nlp.disable_pipe(f"{name}1")
|
||||
assert nlp._disabled == set([f"{name}1"])
|
||||
assert nlp._pipe_names == [f"{name}1", f"{name}2"]
|
||||
assert nlp.pipe_names == [f"{name}2"]
|
||||
assert nlp.config["nlp"]["disabled"] == [f"{name}1"]
|
||||
nlp("hello")
|
||||
assert results[f"{name}1"] == "" # didn't run
|
||||
assert results[f"{name}2"] == "hello" # ran
|
||||
nlp.enable_pipe(f"{name}1")
|
||||
assert nlp._disabled == set()
|
||||
assert nlp.pipe_names == [f"{name}1", f"{name}2"]
|
||||
assert nlp.config["nlp"]["disabled"] == []
|
||||
nlp("world")
|
||||
assert results[f"{name}1"] == "world"
|
||||
assert results[f"{name}2"] == "world"
|
||||
nlp.disable_pipe(f"{name}2")
|
||||
nlp.remove_pipe(f"{name}2")
|
||||
assert nlp._pipeline == [(f"{name}1", c1)]
|
||||
assert nlp.pipeline == [(f"{name}1", c1)]
|
||||
assert nlp._pipe_names == [f"{name}1"]
|
||||
assert nlp.pipe_names == [f"{name}1"]
|
||||
assert nlp._disabled == set()
|
||||
assert nlp.config["nlp"]["disabled"] == []
|
||||
nlp.rename_pipe(f"{name}1", name)
|
||||
assert nlp._pipeline == [(name, c1)]
|
||||
assert nlp._pipe_names == [name]
|
||||
nlp("!")
|
||||
assert results[f"{name}1"] == "!"
|
||||
assert results[f"{name}2"] == "world"
|
||||
with pytest.raises(ValueError):
|
||||
nlp.disable_pipe(f"{name}2")
|
||||
nlp.disable_pipe(name)
|
||||
assert nlp._pipe_names == [name]
|
||||
assert nlp.pipe_names == []
|
||||
assert nlp.config["nlp"]["disabled"] == [name]
|
||||
nlp("?")
|
||||
assert results[f"{name}1"] == "!"
|
||||
|
|
|
@ -161,6 +161,7 @@ def test_issue4674():
|
|||
assert kb2.get_size_entities() == 1
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="API change: disable just disables, new exclude arg")
|
||||
def test_issue4707():
|
||||
"""Tests that disabled component names are also excluded from nlp.from_disk
|
||||
by default when loading a model.
|
||||
|
|
|
@ -6,6 +6,8 @@ from spacy.pipeline.dep_parser import DEFAULT_PARSER_MODEL
|
|||
from spacy.pipeline.tagger import DEFAULT_TAGGER_MODEL
|
||||
from spacy.pipeline.textcat import DEFAULT_TEXTCAT_MODEL
|
||||
from spacy.pipeline.senter import DEFAULT_SENTER_MODEL
|
||||
from spacy.lang.en import English
|
||||
import spacy
|
||||
|
||||
from ..util import make_tempdir
|
||||
|
||||
|
@ -173,3 +175,34 @@ def test_serialize_sentencerecognizer(en_vocab):
|
|||
sr_b = sr.to_bytes()
|
||||
sr_d = SentenceRecognizer(en_vocab, model).from_bytes(sr_b)
|
||||
assert sr.to_bytes() == sr_d.to_bytes()
|
||||
|
||||
|
||||
def test_serialize_pipeline_disable_enable():
|
||||
nlp = English()
|
||||
nlp.add_pipe("ner")
|
||||
nlp.add_pipe("tagger")
|
||||
nlp.disable_pipe("tagger")
|
||||
assert nlp.config["nlp"]["disabled"] == ["tagger"]
|
||||
config = nlp.config.copy()
|
||||
nlp2 = English.from_config(config)
|
||||
assert nlp2.pipe_names == ["ner"]
|
||||
assert nlp2._pipe_names == ["ner", "tagger"]
|
||||
assert nlp2._disabled == set(["tagger"])
|
||||
assert nlp2.config["nlp"]["disabled"] == ["tagger"]
|
||||
with make_tempdir() as d:
|
||||
nlp2.to_disk(d)
|
||||
nlp3 = spacy.load(d)
|
||||
assert nlp3.pipe_names == ["ner"]
|
||||
assert nlp3._pipe_names == ["ner", "tagger"]
|
||||
with make_tempdir() as d:
|
||||
nlp3.to_disk(d)
|
||||
nlp4 = spacy.load(d, disable=["ner"])
|
||||
assert nlp4.pipe_names == []
|
||||
assert nlp4._pipe_names == ["ner", "tagger"]
|
||||
assert nlp4._disabled == set(["ner", "tagger"])
|
||||
with make_tempdir() as d:
|
||||
nlp.to_disk(d)
|
||||
nlp5 = spacy.load(d, exclude=["tagger"])
|
||||
assert nlp5.pipe_names == ["ner"]
|
||||
assert nlp5._pipe_names == ["ner"]
|
||||
assert nlp5._disabled == set()
|
||||
|
|
|
@ -216,6 +216,7 @@ def load_model(
|
|||
*,
|
||||
vocab: Union["Vocab", bool] = True,
|
||||
disable: Iterable[str] = tuple(),
|
||||
exclude: Iterable[str] = tuple(),
|
||||
config: Union[Dict[str, Any], Config] = SimpleFrozenDict(),
|
||||
) -> "Language":
|
||||
"""Load a model from a package or data path.
|
||||
|
@ -228,7 +229,7 @@ def load_model(
|
|||
keyed by section values in dot notation.
|
||||
RETURNS (Language): The loaded nlp object.
|
||||
"""
|
||||
kwargs = {"vocab": vocab, "disable": disable, "config": config}
|
||||
kwargs = {"vocab": vocab, "disable": disable, "exclude": exclude, "config": config}
|
||||
if isinstance(name, str): # name or string path
|
||||
if name.startswith("blank:"): # shortcut for blank model
|
||||
return get_lang_class(name.replace("blank:", ""))()
|
||||
|
@ -248,6 +249,7 @@ def load_model_from_package(
|
|||
*,
|
||||
vocab: Union["Vocab", bool] = True,
|
||||
disable: Iterable[str] = tuple(),
|
||||
exclude: Iterable[str] = tuple(),
|
||||
config: Union[Dict[str, Any], Config] = SimpleFrozenDict(),
|
||||
) -> "Language":
|
||||
"""Load a model from an installed package.
|
||||
|
@ -255,13 +257,17 @@ def load_model_from_package(
|
|||
name (str): The package name.
|
||||
vocab (Vocab / True): Optional vocab to pass in on initialization. If True,
|
||||
a new Vocab object will be created.
|
||||
disable (Iterable[str]): Names of pipeline components to disable.
|
||||
disable (Iterable[str]): Names of pipeline components to disable. Disabled
|
||||
pipes will be loaded but they won't be run unless you explicitly
|
||||
enable them by calling nlp.enable_pipe.
|
||||
exclude (Iterable[str]): Names of pipeline components to exclude. Excluded
|
||||
components won't be loaded.
|
||||
config (Dict[str, Any] / Config): Config overrides as nested dict or dict
|
||||
keyed by section values in dot notation.
|
||||
RETURNS (Language): The loaded nlp object.
|
||||
"""
|
||||
cls = importlib.import_module(name)
|
||||
return cls.load(vocab=vocab, disable=disable, config=config)
|
||||
return cls.load(vocab=vocab, disable=disable, exclude=exclude, config=config)
|
||||
|
||||
|
||||
def load_model_from_path(
|
||||
|
@ -270,6 +276,7 @@ def load_model_from_path(
|
|||
meta: Optional[Dict[str, Any]] = None,
|
||||
vocab: Union["Vocab", bool] = True,
|
||||
disable: Iterable[str] = tuple(),
|
||||
exclude: Iterable[str] = tuple(),
|
||||
config: Union[Dict[str, Any], Config] = SimpleFrozenDict(),
|
||||
) -> "Language":
|
||||
"""Load a model from a data directory path. Creates Language class with
|
||||
|
@ -279,7 +286,11 @@ def load_model_from_path(
|
|||
meta (Dict[str, Any]): Optional model meta.
|
||||
vocab (Vocab / True): Optional vocab to pass in on initialization. If True,
|
||||
a new Vocab object will be created.
|
||||
disable (Iterable[str]): Names of pipeline components to disable.
|
||||
disable (Iterable[str]): Names of pipeline components to disable. Disabled
|
||||
pipes will be loaded but they won't be run unless you explicitly
|
||||
enable them by calling nlp.enable_pipe.
|
||||
exclude (Iterable[str]): Names of pipeline components to exclude. Excluded
|
||||
components won't be loaded.
|
||||
config (Dict[str, Any] / Config): Config overrides as nested dict or dict
|
||||
keyed by section values in dot notation.
|
||||
RETURNS (Language): The loaded nlp object.
|
||||
|
@ -290,8 +301,10 @@ def load_model_from_path(
|
|||
meta = get_model_meta(model_path)
|
||||
config_path = model_path / "config.cfg"
|
||||
config = load_config(config_path, overrides=dict_to_dot(config))
|
||||
nlp, _ = load_model_from_config(config, vocab=vocab, disable=disable)
|
||||
return nlp.from_disk(model_path, exclude=disable)
|
||||
nlp, _ = load_model_from_config(
|
||||
config, vocab=vocab, disable=disable, exclude=exclude
|
||||
)
|
||||
return nlp.from_disk(model_path, exclude=exclude)
|
||||
|
||||
|
||||
def load_model_from_config(
|
||||
|
@ -299,6 +312,7 @@ def load_model_from_config(
|
|||
*,
|
||||
vocab: Union["Vocab", bool] = True,
|
||||
disable: Iterable[str] = tuple(),
|
||||
exclude: Iterable[str] = tuple(),
|
||||
auto_fill: bool = False,
|
||||
validate: bool = True,
|
||||
) -> Tuple["Language", Config]:
|
||||
|
@ -309,7 +323,11 @@ def load_model_from_config(
|
|||
meta (Dict[str, Any]): Optional model meta.
|
||||
vocab (Vocab / True): Optional vocab to pass in on initialization. If True,
|
||||
a new Vocab object will be created.
|
||||
disable (Iterable[str]): Names of pipeline components to disable.
|
||||
disable (Iterable[str]): Names of pipeline components to disable. Disabled
|
||||
pipes will be loaded but they won't be run unless you explicitly
|
||||
enable them by calling nlp.enable_pipe.
|
||||
exclude (Iterable[str]): Names of pipeline components to exclude. Excluded
|
||||
components won't be loaded.
|
||||
auto_fill (bool): Whether to auto-fill config with missing defaults.
|
||||
validate (bool): Whether to show config validation errors.
|
||||
RETURNS (Language): The loaded nlp object.
|
||||
|
@ -323,7 +341,12 @@ def load_model_from_config(
|
|||
# registry, including custom subclasses provided via entry points
|
||||
lang_cls = get_lang_class(nlp_config["lang"])
|
||||
nlp = lang_cls.from_config(
|
||||
config, vocab=vocab, disable=disable, auto_fill=auto_fill, validate=validate,
|
||||
config,
|
||||
vocab=vocab,
|
||||
disable=disable,
|
||||
exclude=exclude,
|
||||
auto_fill=auto_fill,
|
||||
validate=validate,
|
||||
)
|
||||
return nlp, nlp.resolved
|
||||
|
||||
|
@ -333,6 +356,7 @@ def load_model_from_init_py(
|
|||
*,
|
||||
vocab: Union["Vocab", bool] = True,
|
||||
disable: Iterable[str] = tuple(),
|
||||
exclude: Iterable[str] = tuple(),
|
||||
config: Union[Dict[str, Any], Config] = SimpleFrozenDict(),
|
||||
) -> "Language":
|
||||
"""Helper function to use in the `load()` method of a model package's
|
||||
|
@ -340,7 +364,11 @@ def load_model_from_init_py(
|
|||
|
||||
vocab (Vocab / True): Optional vocab to pass in on initialization. If True,
|
||||
a new Vocab object will be created.
|
||||
disable (Iterable[str]): Names of pipeline components to disable.
|
||||
disable (Iterable[str]): Names of pipeline components to disable. Disabled
|
||||
pipes will be loaded but they won't be run unless you explicitly
|
||||
enable them by calling nlp.enable_pipe.
|
||||
exclude (Iterable[str]): Names of pipeline components to exclude. Excluded
|
||||
components won't be loaded.
|
||||
config (Dict[str, Any] / Config): Config overrides as nested dict or dict
|
||||
keyed by section values in dot notation.
|
||||
RETURNS (Language): The loaded nlp object.
|
||||
|
@ -352,7 +380,12 @@ def load_model_from_init_py(
|
|||
if not model_path.exists():
|
||||
raise IOError(Errors.E052.format(path=data_path))
|
||||
return load_model_from_path(
|
||||
data_path, vocab=vocab, meta=meta, disable=disable, config=config
|
||||
data_path,
|
||||
vocab=vocab,
|
||||
meta=meta,
|
||||
disable=disable,
|
||||
exclude=exclude,
|
||||
config=config,
|
||||
)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue