From ddd6fda59cb5499729c936400998e0137c995bf1 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Tue, 8 Oct 2019 12:21:03 +0200 Subject: [PATCH] Add registry for model creation functions ('architectures') (#4395) * Add architecture registry * Add test for arch registry * Add error for model architectures --- spacy/__init__.py | 1 + spacy/errors.py | 1 + spacy/tests/test_register_architecture.py | 19 ++++++++++++++ spacy/util.py | 30 +++++++++++++++++++++++ 4 files changed, 51 insertions(+) create mode 100644 spacy/tests/test_register_architecture.py diff --git a/spacy/__init__.py b/spacy/__init__.py index 9edbab198..8930b1d4e 100644 --- a/spacy/__init__.py +++ b/spacy/__init__.py @@ -14,6 +14,7 @@ from .glossary import explain from .about import __version__ from .errors import Errors, Warnings, deprecation_warning from . import util +from .util import register_architecture, get_architecture if sys.maxunicode == 65535: diff --git a/spacy/errors.py b/spacy/errors.py index a4b16f6fa..de93eaf2e 100644 --- a/spacy/errors.py +++ b/spacy/errors.py @@ -496,6 +496,7 @@ class Errors(object): E173 = ("As of v2.2, the Lemmatizer is initialized with an instance of " "Lookups containing the lemmatization tables. See the docs for " "details: https://spacy.io/api/lemmatizer#init") + E174 = ("Architecture {name} not found in registry. Available names: {names}") @add_codes diff --git a/spacy/tests/test_register_architecture.py b/spacy/tests/test_register_architecture.py new file mode 100644 index 000000000..0c1b5b16f --- /dev/null +++ b/spacy/tests/test_register_architecture.py @@ -0,0 +1,19 @@ +# coding: utf8 +from __future__ import unicode_literals + +import pytest +from spacy import register_architecture +from spacy import get_architecture +from thinc.v2v import Affine + + +@register_architecture("my_test_function") +def create_model(nr_in, nr_out): + return Affine(nr_in, nr_out) + + +def test_get_architecture(): + arch = get_architecture("my_test_function") + assert arch is create_model + with pytest.raises(KeyError): + get_architecture("not_an_existing_key") diff --git a/spacy/util.py b/spacy/util.py index 39cb73c05..d56f39a78 100644 --- a/spacy/util.py +++ b/spacy/util.py @@ -32,6 +32,7 @@ from .errors import Errors, Warnings, deprecation_warning LANGUAGES = {} +ARCHITECTURES = {} _data_path = Path(__file__).parent / "data" _PRINT_ENV = False @@ -48,6 +49,7 @@ class ENTRY_POINTS(object): languages = "spacy_languages" displacy_colors = "spacy_displacy_colors" lookups = "spacy_lookups" + architectures = "spacy_architectures" def set_env_log(value): @@ -119,6 +121,34 @@ def set_data_path(path): _data_path = ensure_path(path) +def register_architecture(name, arch=None): + """Decorator to register an architecture. An architecture is a function + that returns a Thinc Model object. + """ + global ARCHITECTURES + if arch is not None: + ARCHITECTURES[name] = arch + return arch + + def do_registration(arch): + ARCHITECTURES[name] = arch + return arch + + return do_registration + + +def get_architecture(name): + """Get a model architecture function by name.""" + # Check if an entry point is exposed for the architecture code + entry_point = get_entry_point(ENTRY_POINTS.architectures, name) + if entry_point is not None: + ARCHITECTURES[name] = entry_point + if name not in ARCHITECTURES: + names = ", ".join(sorted(ARCHITECTURES.keys())) + raise KeyError(Errors.E174.format(name=name, names=names)) + return ARCHITECTURES[name] + + def ensure_path(path): """Ensure string is converted to a Path.