mirror of https://github.com/explosion/spaCy.git
Add registry for model creation functions ('architectures') (#4395)
* Add architecture registry
* Add test for arch registry
* Add error for model architectures
This commit is contained in:
parent
650cbfe82d
commit
ddd6fda59c
|
@ -14,6 +14,7 @@ from .glossary import explain
|
||||||
from .about import __version__
|
from .about import __version__
|
||||||
from .errors import Errors, Warnings, deprecation_warning
|
from .errors import Errors, Warnings, deprecation_warning
|
||||||
from . import util
|
from . import util
|
||||||
|
from .util import register_architecture, get_architecture
|
||||||
|
|
||||||
|
|
||||||
if sys.maxunicode == 65535:
|
if sys.maxunicode == 65535:
|
||||||
|
|
|
@ -496,6 +496,7 @@ class Errors(object):
|
||||||
E173 = ("As of v2.2, the Lemmatizer is initialized with an instance of "
|
E173 = ("As of v2.2, the Lemmatizer is initialized with an instance of "
|
||||||
"Lookups containing the lemmatization tables. See the docs for "
|
"Lookups containing the lemmatization tables. See the docs for "
|
||||||
"details: https://spacy.io/api/lemmatizer#init")
|
"details: https://spacy.io/api/lemmatizer#init")
|
||||||
|
E174 = ("Architecture {name} not found in registry. Available names: {names}")
|
||||||
|
|
||||||
|
|
||||||
@add_codes
|
@add_codes
|
||||||
|
|
|
@ -0,0 +1,19 @@
|
||||||
|
# coding: utf8
|
||||||
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from spacy import register_architecture
|
||||||
|
from spacy import get_architecture
|
||||||
|
from thinc.v2v import Affine
|
||||||
|
|
||||||
|
|
||||||
|
@register_architecture("my_test_function")
def create_model(nr_in, nr_out):
    """Dummy architecture, registered under "my_test_function" at import time
    so that the registry test below has something to look up.
    """
    # NOTE(review): thinc's Affine appears to take (nO, nI), i.e. output width
    # first -- confirm the argument order if this function is ever called.
    model = Affine(nr_in, nr_out)
    return model
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_architecture():
    """A registered name resolves to the very function that was registered,
    and an unregistered name raises a KeyError.
    """
    # The decorator must hand back the original function object unchanged.
    assert get_architecture("my_test_function") is create_model
    with pytest.raises(KeyError):
        get_architecture("not_an_existing_key")
|
|
@ -32,6 +32,7 @@ from .errors import Errors, Warnings, deprecation_warning
|
||||||
|
|
||||||
|
|
||||||
LANGUAGES = {}
|
LANGUAGES = {}
|
||||||
|
ARCHITECTURES = {}
|
||||||
_data_path = Path(__file__).parent / "data"
|
_data_path = Path(__file__).parent / "data"
|
||||||
_PRINT_ENV = False
|
_PRINT_ENV = False
|
||||||
|
|
||||||
|
@ -48,6 +49,7 @@ class ENTRY_POINTS(object):
|
||||||
languages = "spacy_languages"
|
languages = "spacy_languages"
|
||||||
displacy_colors = "spacy_displacy_colors"
|
displacy_colors = "spacy_displacy_colors"
|
||||||
lookups = "spacy_lookups"
|
lookups = "spacy_lookups"
|
||||||
|
architectures = "spacy_architectures"
|
||||||
|
|
||||||
|
|
||||||
def set_env_log(value):
|
def set_env_log(value):
|
||||||
|
@ -119,6 +121,34 @@ def set_data_path(path):
|
||||||
_data_path = ensure_path(path)
|
_data_path = ensure_path(path)
|
||||||
|
|
||||||
|
|
||||||
|
def register_architecture(name, arch=None):
    """Register a model architecture under a name. An architecture is a
    function that returns a Thinc Model object.

    Can be called directly, `register_architecture("my_arch", func)`, or used
    as a decorator, `@register_architecture("my_arch")`. Registering a name
    that already exists replaces the previous entry.

    name: The key to store the architecture under in ARCHITECTURES.
    arch: The architecture function. If None (the default), a decorator is
        returned instead of registering immediately.
    RETURNS: The unchanged `arch` if it was given, otherwise a decorator that
        registers the function it is applied to and returns it unchanged.
    """
    # ARCHITECTURES is only mutated, never rebound, so no `global` statement
    # is needed here.
    def do_registration(arch):
        ARCHITECTURES[name] = arch
        return arch

    if arch is not None:
        return do_registration(arch)
    return do_registration
|
||||||
|
|
||||||
|
|
||||||
|
def get_architecture(name):
    """Look up a model architecture function by its registered name.

    The entry point group ENTRY_POINTS.architectures is consulted first: if
    get_entry_point returns something for the name, it is stored in
    ARCHITECTURES, replacing any function previously registered under that
    name.

    name: The name of the architecture.
    RETURNS: The object stored in ARCHITECTURES under that name.
    RAISES: KeyError (with message E174, listing the available names in
        sorted order) if the name is not registered.
    """
    from_entry_point = get_entry_point(ENTRY_POINTS.architectures, name)
    if from_entry_point is not None:
        ARCHITECTURES[name] = from_entry_point
    if name in ARCHITECTURES:
        return ARCHITECTURES[name]
    available = ", ".join(sorted(ARCHITECTURES))
    raise KeyError(Errors.E174.format(name=name, names=available))
|
||||||
|
|
||||||
|
|
||||||
def ensure_path(path):
|
def ensure_path(path):
|
||||||
"""Ensure string is converted to a Path.
|
"""Ensure string is converted to a Path.
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue