mirror of https://github.com/explosion/spaCy.git
Add registry for model creation functions ('architectures') (#4395)
* Add architecture registry
* Add test for arch registry
* Add error for model architectures
This commit is contained in:
parent
650cbfe82d
commit
ddd6fda59c
|
@ -14,6 +14,7 @@ from .glossary import explain
|
|||
from .about import __version__
|
||||
from .errors import Errors, Warnings, deprecation_warning
|
||||
from . import util
|
||||
from .util import register_architecture, get_architecture
|
||||
|
||||
|
||||
if sys.maxunicode == 65535:
|
||||
|
|
|
@ -496,6 +496,7 @@ class Errors(object):
|
|||
E173 = ("As of v2.2, the Lemmatizer is initialized with an instance of "
|
||||
"Lookups containing the lemmatization tables. See the docs for "
|
||||
"details: https://spacy.io/api/lemmatizer#init")
|
||||
E174 = ("Architecture {name} not found in registry. Available names: {names}")
|
||||
|
||||
|
||||
@add_codes
|
||||
|
|
|
@ -0,0 +1,19 @@
|
|||
# coding: utf8
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import pytest
|
||||
from spacy import register_architecture
|
||||
from spacy import get_architecture
|
||||
from thinc.v2v import Affine
|
||||
|
||||
|
||||
@register_architecture("my_test_function")
def create_model(nr_in, nr_out):
    """Dummy architecture registered under the name "my_test_function".

    Exists so the registry test has a known function to look up; the test
    only checks identity and never calls it.
    """
    # NOTE(review): thinc's Affine appears to take (nO, nI), i.e. output
    # width first -- confirm the argument order if this is ever called.
    layer = Affine(nr_in, nr_out)
    return layer
|
||||
|
||||
|
||||
def test_get_architecture():
    """Check lookup of a registered architecture and the unknown-name error."""
    # The registry must hand back the very same function object that was
    # decorated above, not a copy or wrapper.
    looked_up = get_architecture("my_test_function")
    assert looked_up is create_model
    # A name that was never registered is reported as a KeyError.
    with pytest.raises(KeyError):
        get_architecture("not_an_existing_key")
|
|
@ -32,6 +32,7 @@ from .errors import Errors, Warnings, deprecation_warning
|
|||
|
||||
|
||||
LANGUAGES = {}
|
||||
ARCHITECTURES = {}
|
||||
_data_path = Path(__file__).parent / "data"
|
||||
_PRINT_ENV = False
|
||||
|
||||
|
@ -48,6 +49,7 @@ class ENTRY_POINTS(object):
|
|||
languages = "spacy_languages"
|
||||
displacy_colors = "spacy_displacy_colors"
|
||||
lookups = "spacy_lookups"
|
||||
architectures = "spacy_architectures"
|
||||
|
||||
|
||||
def set_env_log(value):
|
||||
|
@ -119,6 +121,34 @@ def set_data_path(path):
|
|||
_data_path = ensure_path(path)
|
||||
|
||||
|
||||
def register_architecture(name, arch=None):
    """Register a model architecture under a name.

    An architecture is a function that returns a Thinc Model object. The
    function can be passed directly, or this can be used as a decorator:

        register_architecture("my_arch", create_model)

        @register_architecture("my_arch")
        def create_model(nr_in, nr_out):
            ...

    name (unicode): Key under which the function is stored in ARCHITECTURES.
        An existing entry with the same name is overwritten.
    arch (callable): Optional function to register right away.
    RETURNS (callable): `arch` itself if it was given, otherwise a decorator
        that registers the function it receives and returns it unchanged.
    """

    def do_registration(func):
        # The module-level dict is only mutated, never rebound, so no
        # `global` statement is required here.
        ARCHITECTURES[name] = func
        return func

    if arch is not None:
        # Direct-call form: register immediately and hand the function back.
        return do_registration(arch)
    # Decorator form: defer registration until the function is supplied.
    return do_registration
|
||||
|
||||
|
||||
def get_architecture(name):
    """Look up a model architecture function by its registered name.

    name (unicode): Name to look up.
    RETURNS: The object stored in ARCHITECTURES under `name`.
    RAISES (KeyError): If the name is neither exposed as an entry point nor
        present in the registry; the message lists the available names.
    """
    # Entry points are consulted first on every call: when one is found it is
    # written into the registry, replacing any existing entry for this name.
    exposed = get_entry_point(ENTRY_POINTS.architectures, name)
    if exposed is not None:
        ARCHITECTURES[name] = exposed
    if name in ARCHITECTURES:
        return ARCHITECTURES[name]
    available = sorted(ARCHITECTURES.keys())
    raise KeyError(Errors.E174.format(name=name, names=", ".join(available)))
|
||||
|
||||
|
||||
def ensure_path(path):
|
||||
"""Ensure string is converted to a Path.
|
||||
|
||||
|
|
Loading…
Reference in New Issue