Add registry for model creation functions ('architectures') (#4395)

* Add architecture registry

* Add test for arch registry

* Add error for model architectures
This commit is contained in:
Matthew Honnibal 2019-10-08 12:21:03 +02:00 committed by Ines Montani
parent 650cbfe82d
commit ddd6fda59c
4 changed files with 51 additions and 0 deletions

View File

@ -14,6 +14,7 @@ from .glossary import explain
from .about import __version__
from .errors import Errors, Warnings, deprecation_warning
from . import util
from .util import register_architecture, get_architecture
if sys.maxunicode == 65535:

View File

@ -496,6 +496,7 @@ class Errors(object):
E173 = ("As of v2.2, the Lemmatizer is initialized with an instance of "
"Lookups containing the lemmatization tables. See the docs for "
"details: https://spacy.io/api/lemmatizer#init")
E174 = ("Architecture {name} not found in registry. Available names: {names}")
@add_codes

View File

@ -0,0 +1,19 @@
# coding: utf8
from __future__ import unicode_literals
import pytest
from spacy import register_architecture
from spacy import get_architecture
from thinc.v2v import Affine
@register_architecture("my_test_function")
def create_model(nr_in, nr_out):
return Affine(nr_in, nr_out)
def test_get_architecture():
arch = get_architecture("my_test_function")
assert arch is create_model
with pytest.raises(KeyError):
get_architecture("not_an_existing_key")

View File

@ -32,6 +32,7 @@ from .errors import Errors, Warnings, deprecation_warning
LANGUAGES = {}
ARCHITECTURES = {}
_data_path = Path(__file__).parent / "data"
_PRINT_ENV = False
@ -48,6 +49,7 @@ class ENTRY_POINTS(object):
languages = "spacy_languages"
displacy_colors = "spacy_displacy_colors"
lookups = "spacy_lookups"
architectures = "spacy_architectures"
def set_env_log(value):
@ -119,6 +121,34 @@ def set_data_path(path):
_data_path = ensure_path(path)
def register_architecture(name, arch=None):
"""Decorator to register an architecture. An architecture is a function
that returns a Thinc Model object.
"""
global ARCHITECTURES
if arch is not None:
ARCHITECTURES[name] = arch
return arch
def do_registration(arch):
ARCHITECTURES[name] = arch
return arch
return do_registration
def get_architecture(name):
"""Get a model architecture function by name."""
# Check if an entry point is exposed for the architecture code
entry_point = get_entry_point(ENTRY_POINTS.architectures, name)
if entry_point is not None:
ARCHITECTURES[name] = entry_point
if name not in ARCHITECTURES:
names = ", ".join(sorted(ARCHITECTURES.keys()))
raise KeyError(Errors.E174.format(name=name, names=names))
return ARCHITECTURES[name]
def ensure_path(path):
"""Ensure string is converted to a Path.