diff --git a/spacy/__init__.py b/spacy/__init__.py index 8e9c8db69..5c286ed80 100644 --- a/spacy/__init__.py +++ b/spacy/__init__.py @@ -27,8 +27,8 @@ if sys.maxunicode == 65535: def load( name: Union[str, Path], - disable: Iterable[str] = tuple(), - exclude: Iterable[str] = tuple(), + disable: Iterable[str] = util.SimpleFrozenList(), + exclude: Iterable[str] = util.SimpleFrozenList(), config: Union[Dict[str, Any], Config] = util.SimpleFrozenDict(), ) -> Language: """Load a spaCy model from an installed package or a local path. diff --git a/spacy/cli/project/dvc.py b/spacy/cli/project/dvc.py index e0f6cd430..de0480bad 100644 --- a/spacy/cli/project/dvc.py +++ b/spacy/cli/project/dvc.py @@ -1,6 +1,6 @@ """This module contains helpers and subcommands for integrating spaCy projects with Data Version Controk (DVC). https://dvc.org""" -from typing import Dict, Any, List, Optional +from typing import Dict, Any, List, Optional, Iterable import subprocess from pathlib import Path from wasabi import msg @@ -8,6 +8,7 @@ from wasabi import msg from .._util import PROJECT_FILE, load_project_config, get_hash, project_cli from .._util import Arg, Opt, NAME, COMMAND from ...util import working_dir, split_command, join_command, run_command +from ...util import SimpleFrozenList DVC_CONFIG = "dvc.yaml" @@ -130,7 +131,7 @@ def update_dvc_config( def run_dvc_commands( - commands: List[str] = tuple(), flags: Dict[str, bool] = {}, + commands: Iterable[str] = SimpleFrozenList(), flags: Dict[str, bool] = {}, ) -> None: """Run a sequence of DVC commands in a subprocess, in order. diff --git a/spacy/cli/project/run.py b/spacy/cli/project/run.py index 7b579314b..bacd7f04b 100644 --- a/spacy/cli/project/run.py +++ b/spacy/cli/project/run.py @@ -1,10 +1,11 @@ -from typing import Optional, List, Dict, Sequence, Any +from typing import Optional, List, Dict, Sequence, Any, Iterable from pathlib import Path from wasabi import msg import sys import srsly from ...util import working_dir, run_command, split_command, is_cwd, join_command +from ...util import SimpleFrozenList from .._util import PROJECT_FILE, PROJECT_LOCK, load_project_config, get_hash from .._util import get_checksum, project_cli, Arg, Opt, COMMAND @@ -115,7 +116,9 @@ def print_run_help(project_dir: Path, subcommand: Optional[str] = None) -> None: def run_commands( - commands: List[str] = tuple(), silent: bool = False, dry: bool = False, + commands: Iterable[str] = SimpleFrozenList(), + silent: bool = False, + dry: bool = False, ) -> None: """Run a sequence of commands in a subprocess, in order. diff --git a/spacy/errors.py b/spacy/errors.py index a66861667..75c41065a 100644 --- a/spacy/errors.py +++ b/spacy/errors.py @@ -472,6 +472,13 @@ class Errors: E199 = ("Unable to merge 0-length span at doc[{start}:{end}].") # TODO: fix numbering after merging develop into master + E926 = ("It looks like you're trying to modify nlp.{attr} directly. This " + "doesn't work because it's an immutable computed property. If you " + "need to modify the pipeline, use the built-in methods like " + "nlp.add_pipe, nlp.remove_pipe, nlp.disable_pipe or nlp.enable_pipe " + "instead.") + E927 = ("Can't write to frozen list Maybe you're trying to modify a computed " + "property or default function argument?") E928 = ("A 'KnowledgeBase' should be written to / read from a file, but the " "provided argument {loc} is an existing directory.") E929 = ("A 'KnowledgeBase' could not be read from {loc} - the path does " diff --git a/spacy/language.py b/spacy/language.py index 140452a77..a8d75cfae 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -20,7 +20,7 @@ from .vocab import Vocab, create_vocab from .pipe_analysis import validate_attrs, analyze_pipes, print_pipe_analysis from .gold import Example, validate_examples from .scorer import Scorer -from .util import create_default_optimizer, registry +from .util import create_default_optimizer, registry, SimpleFrozenList from .util import SimpleFrozenDict, combine_score_weights, CONFIG_SECTION_ORDER from .lang.tokenizer_exceptions import URL_MATCH, BASE_EXCEPTIONS from .lang.punctuation import TOKENIZER_PREFIXES, TOKENIZER_SUFFIXES @@ -159,7 +159,7 @@ class Language: self.vocab: Vocab = vocab if self.lang is None: self.lang = self.vocab.lang - self.components = [] + self._components = [] self._disabled = set() self.max_length = max_length self.resolved = {} @@ -207,11 +207,11 @@ class Language: "keys": self.vocab.vectors.n_keys, "name": self.vocab.vectors.name, } - self._meta["labels"] = self.pipe_labels + self._meta["labels"] = dict(self.pipe_labels) # TODO: Adding this back to prevent breaking people's code etc., but # we should consider removing it - self._meta["pipeline"] = self.pipe_names - self._meta["disabled"] = self.disabled + self._meta["pipeline"] = list(self.pipe_names) + self._meta["disabled"] = list(self.disabled) return self._meta @meta.setter @@ -240,8 +240,8 @@ class Language: pipeline[pipe_name] = {"factory": pipe_meta.factory, **pipe_config} if pipe_meta.default_score_weights: score_weights.append(pipe_meta.default_score_weights) - self._config["nlp"]["pipeline"] = self.component_names - self._config["nlp"]["disabled"] = self.disabled + self._config["nlp"]["pipeline"] = list(self.component_names) + self._config["nlp"]["disabled"] = list(self.disabled) self._config["components"] = pipeline self._config["training"]["score_weights"] = combine_score_weights(score_weights) if not srsly.is_json_serializable(self._config): @@ -260,7 +260,8 @@ class Language: """ # Make sure the disabled components are returned in the order they # appear in the pipeline (which isn't guaranteed by the set) - return [name for name, _ in self.components if name in self._disabled] + names = [name for name, _ in self._components if name in self._disabled] + return SimpleFrozenList(names, error=Errors.E926.format(attr="disabled")) @property def factory_names(self) -> List[str]: @@ -268,7 +269,17 @@ class Language: RETURNS (List[str]): The factory names. """ - return list(self.factories.keys()) + names = list(self.factories.keys()) + return SimpleFrozenList(names) + + @property + def components(self) -> List[Tuple[str, Callable[[Doc], Doc]]]: + """Get all (name, component) tuples in the pipeline, including the + currently disabled components. + """ + return SimpleFrozenList( + self._components, error=Errors.E926.format(attr="components") + ) @property def component_names(self) -> List[str]: @@ -277,7 +288,8 @@ class Language: RETURNS (List[str]): List of component name strings, in order. """ - return [pipe_name for pipe_name, _ in self.components] + names = [pipe_name for pipe_name, _ in self._components] + return SimpleFrozenList(names, error=Errors.E926.format(attr="component_names")) @property def pipeline(self) -> List[Tuple[str, Callable[[Doc], Doc]]]: @@ -287,7 +299,8 @@ class Language: RETURNS (List[Tuple[str, Callable[[Doc], Doc]]]): The pipeline. """ - return [(name, p) for name, p in self.components if name not in self._disabled] + pipes = [(n, p) for n, p in self._components if n not in self._disabled] + return SimpleFrozenList(pipes, error=Errors.E926.format(attr="pipeline")) @property def pipe_names(self) -> List[str]: @@ -295,7 +308,8 @@ class Language: RETURNS (List[str]): List of component name strings, in order. """ - return [pipe_name for pipe_name, _ in self.pipeline] + names = [pipe_name for pipe_name, _ in self.pipeline] + return SimpleFrozenList(names, error=Errors.E926.format(attr="pipe_names")) @property def pipe_factories(self) -> Dict[str, str]: @@ -304,9 +318,9 @@ class Language: RETURNS (Dict[str, str]): Factory names, keyed by component names. """ factories = {} - for pipe_name, pipe in self.components: + for pipe_name, pipe in self._components: factories[pipe_name] = self.get_pipe_meta(pipe_name).factory - return factories + return SimpleFrozenDict(factories) @property def pipe_labels(self) -> Dict[str, List[str]]: @@ -316,10 +330,10 @@ class Language: RETURNS (Dict[str, List[str]]): Labels keyed by component name. """ labels = {} - for name, pipe in self.components: + for name, pipe in self._components: if hasattr(pipe, "labels"): labels[name] = list(pipe.labels) - return labels + return SimpleFrozenDict(labels) @classmethod def has_factory(cls, name: str) -> bool: @@ -390,10 +404,10 @@ class Language: name: str, *, default_config: Dict[str, Any] = SimpleFrozenDict(), - assigns: Iterable[str] = tuple(), - requires: Iterable[str] = tuple(), + assigns: Iterable[str] = SimpleFrozenList(), + requires: Iterable[str] = SimpleFrozenList(), retokenizes: bool = False, - scores: Iterable[str] = tuple(), + scores: Iterable[str] = SimpleFrozenList(), default_score_weights: Dict[str, float] = SimpleFrozenDict(), func: Optional[Callable] = None, ) -> Callable: @@ -471,8 +485,8 @@ class Language: cls, name: Optional[str] = None, *, - assigns: Iterable[str] = tuple(), - requires: Iterable[str] = tuple(), + assigns: Iterable[str] = SimpleFrozenList(), + requires: Iterable[str] = SimpleFrozenList(), retokenizes: bool = False, func: Optional[Callable[[Doc], Doc]] = None, ) -> Callable: @@ -544,7 +558,7 @@ class Language: DOCS: https://spacy.io/api/language#get_pipe """ - for pipe_name, component in self.components: + for pipe_name, component in self._components: if pipe_name == name: return component raise KeyError(Errors.E001.format(name=name, opts=self.component_names)) @@ -718,7 +732,7 @@ class Language: ) pipe_index = self._get_pipe_index(before, after, first, last) self._pipe_meta[name] = self.get_factory_meta(factory_name) - self.components.insert(pipe_index, (name, pipe_component)) + self._components.insert(pipe_index, (name, pipe_component)) return pipe_component def _get_pipe_index( @@ -743,7 +757,7 @@ class Language: Errors.E006.format(args=all_args, opts=self.component_names) ) if last or not any(value is not None for value in [first, before, after]): - return len(self.components) + return len(self._components) elif first: return 0 elif isinstance(before, str): @@ -761,14 +775,14 @@ class Language: # We're only accepting indices referring to components that exist # (can't just do isinstance here because bools are instance of int, too) elif type(before) == int: - if before >= len(self.components) or before < 0: + if before >= len(self._components) or before < 0: err = Errors.E959.format( dir="before", idx=before, opts=self.component_names ) raise ValueError(err) return before elif type(after) == int: - if after >= len(self.components) or after < 0: + if after >= len(self._components) or after < 0: err = Errors.E959.format( dir="after", idx=after, opts=self.component_names ) @@ -815,7 +829,7 @@ class Language: # to Language.pipeline to make sure the configs are handled correctly pipe_index = self.pipe_names.index(name) self.remove_pipe(name) - if not len(self.components) or pipe_index == len(self.components): + if not len(self._components) or pipe_index == len(self._components): # we have no components to insert before/after, or we're replacing the last component self.add_pipe(factory_name, name=name, config=config, validate=validate) else: @@ -844,7 +858,7 @@ class Language: Errors.E007.format(name=new_name, opts=self.component_names) ) i = self.component_names.index(old_name) - self.components[i] = (new_name, self.components[i][1]) + self._components[i] = (new_name, self._components[i][1]) self._pipe_meta[new_name] = self._pipe_meta.pop(old_name) self._pipe_configs[new_name] = self._pipe_configs.pop(old_name) @@ -858,7 +872,7 @@ class Language: """ if name not in self.component_names: raise ValueError(Errors.E001.format(name=name, opts=self.component_names)) - removed = self.components.pop(self.component_names.index(name)) + removed = self._components.pop(self.component_names.index(name)) # We're only removing the component itself from the metas/configs here # because factory may be used for something else self._pipe_meta.pop(name) @@ -894,7 +908,7 @@ class Language: self, text: str, *, - disable: Iterable[str] = tuple(), + disable: Iterable[str] = SimpleFrozenList(), component_cfg: Optional[Dict[str, Dict[str, Any]]] = None, ) -> Doc: """Apply the pipeline to some text. The text can span multiple sentences, @@ -993,7 +1007,7 @@ class Language: sgd: Optional[Optimizer] = None, losses: Optional[Dict[str, float]] = None, component_cfg: Optional[Dict[str, Dict[str, Any]]] = None, - exclude: Iterable[str] = tuple(), + exclude: Iterable[str] = SimpleFrozenList(), ): """Update the models in the pipeline. @@ -1047,7 +1061,7 @@ class Language: sgd: Optional[Optimizer] = None, losses: Optional[Dict[str, float]] = None, component_cfg: Optional[Dict[str, Dict[str, Any]]] = None, - exclude: Iterable[str] = tuple(), + exclude: Iterable[str] = SimpleFrozenList(), ) -> Dict[str, float]: """Make a "rehearsal" update to the models in the pipeline, to prevent forgetting. Rehearsal updates run an initial copy of the model over some @@ -1276,7 +1290,7 @@ class Language: *, as_tuples: bool = False, batch_size: int = 1000, - disable: Iterable[str] = tuple(), + disable: Iterable[str] = SimpleFrozenList(), cleanup: bool = False, component_cfg: Optional[Dict[str, Dict[str, Any]]] = None, n_process: int = 1, @@ -1436,8 +1450,8 @@ class Language: config: Union[Dict[str, Any], Config] = {}, *, vocab: Union[Vocab, bool] = True, - disable: Iterable[str] = tuple(), - exclude: Iterable[str] = tuple(), + disable: Iterable[str] = SimpleFrozenList(), + exclude: Iterable[str] = SimpleFrozenList(), auto_fill: bool = True, validate: bool = True, ) -> "Language": @@ -1562,7 +1576,7 @@ class Language: return nlp def to_disk( - self, path: Union[str, Path], *, exclude: Iterable[str] = tuple() + self, path: Union[str, Path], *, exclude: Iterable[str] = SimpleFrozenList() ) -> None: """Save the current state to a directory. If a model is loaded, this will include the model. @@ -1580,7 +1594,7 @@ class Language: ) serializers["meta.json"] = lambda p: srsly.write_json(p, self.meta) serializers["config.cfg"] = lambda p: self.config.to_disk(p) - for name, proc in self.components: + for name, proc in self._components: if name in exclude: continue if not hasattr(proc, "to_disk"): @@ -1590,7 +1604,7 @@ class Language: util.to_disk(path, serializers, exclude) def from_disk( - self, path: Union[str, Path], *, exclude: Iterable[str] = tuple() + self, path: Union[str, Path], *, exclude: Iterable[str] = SimpleFrozenList() ) -> "Language": """Loads state from a directory. Modifies the object in place and returns it. If the saved `Language` object contains a model, the @@ -1624,7 +1638,7 @@ class Language: deserializers["tokenizer"] = lambda p: self.tokenizer.from_disk( p, exclude=["vocab"] ) - for name, proc in self.components: + for name, proc in self._components: if name in exclude: continue if not hasattr(proc, "from_disk"): @@ -1640,7 +1654,7 @@ class Language: self._link_components() return self - def to_bytes(self, *, exclude: Iterable[str] = tuple()) -> bytes: + def to_bytes(self, *, exclude: Iterable[str] = SimpleFrozenList()) -> bytes: """Serialize the current state to a binary string. exclude (list): Names of components or serialization fields to exclude. @@ -1653,7 +1667,7 @@ class Language: serializers["tokenizer"] = lambda: self.tokenizer.to_bytes(exclude=["vocab"]) serializers["meta.json"] = lambda: srsly.json_dumps(self.meta) serializers["config.cfg"] = lambda: self.config.to_bytes() - for name, proc in self.components: + for name, proc in self._components: if name in exclude: continue if not hasattr(proc, "to_bytes"): @@ -1662,7 +1676,7 @@ class Language: return util.to_bytes(serializers, exclude) def from_bytes( - self, bytes_data: bytes, *, exclude: Iterable[str] = tuple() + self, bytes_data: bytes, *, exclude: Iterable[str] = SimpleFrozenList() ) -> "Language": """Load state from a binary string. @@ -1687,7 +1701,7 @@ class Language: deserializers["tokenizer"] = lambda b: self.tokenizer.from_bytes( b, exclude=["vocab"] ) - for name, proc in self.components: + for name, proc in self._components: if name in exclude: continue if not hasattr(proc, "from_bytes"): diff --git a/spacy/pipeline/attributeruler.py b/spacy/pipeline/attributeruler.py index d93afc642..374e0d046 100644 --- a/spacy/pipeline/attributeruler.py +++ b/spacy/pipeline/attributeruler.py @@ -12,6 +12,7 @@ from ..symbols import IDS, TAG, POS, MORPH, LEMMA from ..tokens import Doc, Span from ..tokens._retokenize import normalize_token_attrs, set_token_attrs from ..vocab import Vocab +from ..util import SimpleFrozenList from .. import util @@ -220,7 +221,7 @@ class AttributeRuler(Pipe): results.update(Scorer.score_token_attr(examples, "lemma", **kwargs)) return results - def to_bytes(self, exclude: Iterable[str] = tuple()) -> bytes: + def to_bytes(self, exclude: Iterable[str] = SimpleFrozenList()) -> bytes: """Serialize the AttributeRuler to a bytestring. exclude (Iterable[str]): String names of serialization fields to exclude. @@ -236,7 +237,9 @@ class AttributeRuler(Pipe): serialize["indices"] = lambda: srsly.msgpack_dumps(self.indices) return util.to_bytes(serialize, exclude) - def from_bytes(self, bytes_data: bytes, exclude: Iterable[str] = tuple()): + def from_bytes( + self, bytes_data: bytes, exclude: Iterable[str] = SimpleFrozenList() + ): """Load the AttributeRuler from a bytestring. bytes_data (bytes): The data to load. @@ -272,7 +275,9 @@ class AttributeRuler(Pipe): return self - def to_disk(self, path: Union[Path, str], exclude: Iterable[str] = tuple()) -> None: + def to_disk( + self, path: Union[Path, str], exclude: Iterable[str] = SimpleFrozenList() + ) -> None: """Serialize the AttributeRuler to disk. path (Union[Path, str]): A path to a directory. @@ -289,7 +294,7 @@ class AttributeRuler(Pipe): util.to_disk(path, serialize, exclude) def from_disk( - self, path: Union[Path, str], exclude: Iterable[str] = tuple() + self, path: Union[Path, str], exclude: Iterable[str] = SimpleFrozenList() ) -> None: """Load the AttributeRuler from disk. diff --git a/spacy/pipeline/entity_linker.py b/spacy/pipeline/entity_linker.py index d92c700ba..ae4838bed 100644 --- a/spacy/pipeline/entity_linker.py +++ b/spacy/pipeline/entity_linker.py @@ -13,6 +13,7 @@ from ..language import Language from ..vocab import Vocab from ..gold import Example, validate_examples from ..errors import Errors, Warnings +from ..util import SimpleFrozenList from .. import util @@ -404,7 +405,7 @@ class EntityLinker(Pipe): token.ent_kb_id_ = kb_id def to_disk( - self, path: Union[str, Path], *, exclude: Iterable[str] = tuple() + self, path: Union[str, Path], *, exclude: Iterable[str] = SimpleFrozenList(), ) -> None: """Serialize the pipe to disk. @@ -421,7 +422,7 @@ class EntityLinker(Pipe): util.to_disk(path, serialize, exclude) def from_disk( - self, path: Union[str, Path], *, exclude: Iterable[str] = tuple() + self, path: Union[str, Path], *, exclude: Iterable[str] = SimpleFrozenList(), ) -> "EntityLinker": """Load the pipe from disk. Modifies the object in place and returns it. diff --git a/spacy/pipeline/entityruler.py b/spacy/pipeline/entityruler.py index 785d17b6b..5137dfec2 100644 --- a/spacy/pipeline/entityruler.py +++ b/spacy/pipeline/entityruler.py @@ -5,7 +5,7 @@ import srsly from ..language import Language from ..errors import Errors -from ..util import ensure_path, to_disk, from_disk +from ..util import ensure_path, to_disk, from_disk, SimpleFrozenList from ..tokens import Doc, Span from ..matcher import Matcher, PhraseMatcher from ..scorer import Scorer @@ -317,7 +317,7 @@ class EntityRuler: return Scorer.score_spans(examples, "ents", **kwargs) def from_bytes( - self, patterns_bytes: bytes, *, exclude: Iterable[str] = tuple() + self, patterns_bytes: bytes, *, exclude: Iterable[str] = SimpleFrozenList() ) -> "EntityRuler": """Load the entity ruler from a bytestring. @@ -341,7 +341,7 @@ class EntityRuler: self.add_patterns(cfg) return self - def to_bytes(self, *, exclude: Iterable[str] = tuple()) -> bytes: + def to_bytes(self, *, exclude: Iterable[str] = SimpleFrozenList()) -> bytes: """Serialize the entity ruler patterns to a bytestring. RETURNS (bytes): The serialized patterns. @@ -357,7 +357,7 @@ class EntityRuler: return srsly.msgpack_dumps(serial) def from_disk( - self, path: Union[str, Path], *, exclude: Iterable[str] = tuple() + self, path: Union[str, Path], *, exclude: Iterable[str] = SimpleFrozenList() ) -> "EntityRuler": """Load the entity ruler from a file. Expects a file containing newline-delimited JSON (JSONL) with one entry per line. @@ -394,7 +394,7 @@ class EntityRuler: return self def to_disk( - self, path: Union[str, Path], *, exclude: Iterable[str] = tuple() + self, path: Union[str, Path], *, exclude: Iterable[str] = SimpleFrozenList() ) -> None: """Save the entity ruler patterns to a directory. The patterns will be saved as newline-delimited JSON (JSONL). diff --git a/spacy/scorer.py b/spacy/scorer.py index 95fb21168..9bbc64cac 100644 --- a/spacy/scorer.py +++ b/spacy/scorer.py @@ -1,10 +1,10 @@ -from typing import Optional, Iterable, Dict, Any, Callable, Tuple, TYPE_CHECKING +from typing import Optional, Iterable, Dict, Any, Callable, TYPE_CHECKING import numpy as np from .gold import Example from .tokens import Token, Doc, Span from .errors import Errors -from .util import get_lang_class +from .util import get_lang_class, SimpleFrozenList from .morphology import Morphology if TYPE_CHECKING: @@ -317,7 +317,7 @@ class Scorer: attr: str, *, getter: Callable[[Doc, str], Any] = getattr, - labels: Iterable[str] = tuple(), + labels: Iterable[str] = SimpleFrozenList(), multi_label: bool = True, positive_label: Optional[str] = None, threshold: Optional[float] = None, @@ -447,7 +447,7 @@ class Scorer: getter: Callable[[Token, str], Any] = getattr, head_attr: str = "head", head_getter: Callable[[Token, str], Token] = getattr, - ignore_labels: Tuple[str] = tuple(), + ignore_labels: Iterable[str] = SimpleFrozenList(), **cfg, ) -> Dict[str, Any]: """Returns the UAS, LAS, and LAS per type scores for dependency diff --git a/spacy/tests/pipeline/test_pipe_methods.py b/spacy/tests/pipeline/test_pipe_methods.py index 2a1cbad2a..ea09d990c 100644 --- a/spacy/tests/pipeline/test_pipe_methods.py +++ b/spacy/tests/pipeline/test_pipe_methods.py @@ -1,5 +1,6 @@ import pytest from spacy.language import Language +from spacy.util import SimpleFrozenList @pytest.fixture @@ -317,3 +318,31 @@ def test_disable_enable_pipes(): assert nlp.config["nlp"]["disabled"] == [name] nlp("?") assert results[f"{name}1"] == "!" + + +def test_pipe_methods_frozen(): + """Test that spaCy raises custom error messages if "frozen" properties are + accessed. We still want to use a list here to not break backwards + compatibility, but users should see an error if they're trying to append + to nlp.pipeline etc.""" + nlp = Language() + ner = nlp.add_pipe("ner") + assert nlp.pipe_names == ["ner"] + for prop in [ + nlp.pipeline, + nlp.pipe_names, + nlp.components, + nlp.component_names, + nlp.disabled, + nlp.factory_names, + ]: + assert isinstance(prop, list) + assert isinstance(prop, SimpleFrozenList) + with pytest.raises(NotImplementedError): + nlp.pipeline.append(("ner2", ner)) + with pytest.raises(NotImplementedError): + nlp.pipe_names.pop() + with pytest.raises(NotImplementedError): + nlp.components.sort() + with pytest.raises(NotImplementedError): + nlp.component_names.clear() diff --git a/spacy/tests/test_util.py b/spacy/tests/test_util.py index 47111a902..40cd71eb5 100644 --- a/spacy/tests/test_util.py +++ b/spacy/tests/test_util.py @@ -3,10 +3,9 @@ import pytest from .util import get_random_doc from spacy import util -from spacy.util import dot_to_object +from spacy.util import dot_to_object, SimpleFrozenList from thinc.api import Config, Optimizer from spacy.gold.batchers import minibatch_by_words - from ..lang.en import English from ..lang.nl import Dutch from ..language import DEFAULT_CONFIG_PATH @@ -106,3 +105,20 @@ def test_util_dot_section(): assert not dot_to_object(en_config, "nlp.load_vocab_data") assert dot_to_object(nl_config, "nlp.load_vocab_data") assert isinstance(dot_to_object(nl_config, "training.optimizer"), Optimizer) + + +def test_simple_frozen_list(): + t = SimpleFrozenList(["foo", "bar"]) + assert t == ["foo", "bar"] + assert t.index("bar") == 1 # okay method + with pytest.raises(NotImplementedError): + t.append("baz") + with pytest.raises(NotImplementedError): + t.sort() + with pytest.raises(NotImplementedError): + t.extend(["baz"]) + with pytest.raises(NotImplementedError): + t.pop() + t = SimpleFrozenList(["foo", "bar"], error="Error!") + with pytest.raises(NotImplementedError): + t.append("baz") diff --git a/spacy/tokens/_serialize.py b/spacy/tokens/_serialize.py index 9d17cec1c..a257c7919 100644 --- a/spacy/tokens/_serialize.py +++ b/spacy/tokens/_serialize.py @@ -10,7 +10,7 @@ from ..vocab import Vocab from ..compat import copy_reg from ..attrs import SPACY, ORTH, intify_attr from ..errors import Errors -from ..util import ensure_path +from ..util import ensure_path, SimpleFrozenList # fmt: off ALL_ATTRS = ("ORTH", "TAG", "HEAD", "DEP", "ENT_IOB", "ENT_TYPE", "ENT_KB_ID", "LEMMA", "MORPH", "POS") @@ -52,7 +52,7 @@ class DocBin: self, attrs: Iterable[str] = ALL_ATTRS, store_user_data: bool = False, - docs: Iterable[Doc] = tuple(), + docs: Iterable[Doc] = SimpleFrozenList(), ) -> None: """Create a DocBin object to hold serialized annotations. diff --git a/spacy/util.py b/spacy/util.py index d12e54dc7..6f0bf9b00 100644 --- a/spacy/util.py +++ b/spacy/util.py @@ -120,6 +120,47 @@ class SimpleFrozenDict(dict): raise NotImplementedError(self.error) +class SimpleFrozenList(list): + """Wrapper class around a list that lets us raise custom errors if certain + attributes/methods are accessed. Mostly used for properties like + Language.pipeline that return an immutable list (and that we don't want to + convert to a tuple to not break too much backwards compatibility). If a user + accidentally calls nlp.pipeline.append(), we can raise a more helpful error. + """ + + def __init__(self, *args, error: str = Errors.E927) -> None: + """Initialize the frozen list. + + error (str): The error message when user tries to mutate the list. + """ + self.error = error + super().__init__(*args) + + def append(self, *args, **kwargs): + raise NotImplementedError(self.error) + + def clear(self, *args, **kwargs): + raise NotImplementedError(self.error) + + def extend(self, *args, **kwargs): + raise NotImplementedError(self.error) + + def insert(self, *args, **kwargs): + raise NotImplementedError(self.error) + + def pop(self, *args, **kwargs): + raise NotImplementedError(self.error) + + def remove(self, *args, **kwargs): + raise NotImplementedError(self.error) + + def reverse(self, *args, **kwargs): + raise NotImplementedError(self.error) + + def sort(self, *args, **kwargs): + raise NotImplementedError(self.error) + + def lang_class_is_loaded(lang: str) -> bool: """Check whether a Language class is already loaded. Language classes are loaded lazily, to avoid expensive setup code associated with the language @@ -215,8 +256,8 @@ def load_model( name: Union[str, Path], *, vocab: Union["Vocab", bool] = True, - disable: Iterable[str] = tuple(), - exclude: Iterable[str] = tuple(), + disable: Iterable[str] = SimpleFrozenList(), + exclude: Iterable[str] = SimpleFrozenList(), config: Union[Dict[str, Any], Config] = SimpleFrozenDict(), ) -> "Language": """Load a model from a package or data path. @@ -248,8 +289,8 @@ def load_model_from_package( name: str, *, vocab: Union["Vocab", bool] = True, - disable: Iterable[str] = tuple(), - exclude: Iterable[str] = tuple(), + disable: Iterable[str] = SimpleFrozenList(), + exclude: Iterable[str] = SimpleFrozenList(), config: Union[Dict[str, Any], Config] = SimpleFrozenDict(), ) -> "Language": """Load a model from an installed package. @@ -275,8 +316,8 @@ def load_model_from_path( *, meta: Optional[Dict[str, Any]] = None, vocab: Union["Vocab", bool] = True, - disable: Iterable[str] = tuple(), - exclude: Iterable[str] = tuple(), + disable: Iterable[str] = SimpleFrozenList(), + exclude: Iterable[str] = SimpleFrozenList(), config: Union[Dict[str, Any], Config] = SimpleFrozenDict(), ) -> "Language": """Load a model from a data directory path. Creates Language class with @@ -311,8 +352,8 @@ def load_model_from_config( config: Union[Dict[str, Any], Config], *, vocab: Union["Vocab", bool] = True, - disable: Iterable[str] = tuple(), - exclude: Iterable[str] = tuple(), + disable: Iterable[str] = SimpleFrozenList(), + exclude: Iterable[str] = SimpleFrozenList(), auto_fill: bool = False, validate: bool = True, ) -> Tuple["Language", Config]: @@ -355,8 +396,8 @@ def load_model_from_init_py( init_file: Union[Path, str], *, vocab: Union["Vocab", bool] = True, - disable: Iterable[str] = tuple(), - exclude: Iterable[str] = tuple(), + disable: Iterable[str] = SimpleFrozenList(), + exclude: Iterable[str] = SimpleFrozenList(), config: Union[Dict[str, Any], Config] = SimpleFrozenDict(), ) -> "Language": """Helper function to use in the `load()` method of a model package's