mirror of https://github.com/explosion/spaCy.git
Update Language.initialize and support components/tokenizer settings
This commit is contained in:
parent
4925ad760a
commit
dec984a9c1
|
@ -27,7 +27,7 @@ from .lang.punctuation import TOKENIZER_INFIXES
|
|||
from .tokens import Doc
|
||||
from .tokenizer import Tokenizer
|
||||
from .errors import Errors, Warnings
|
||||
from .schemas import ConfigSchema, ConfigSchemaNlp
|
||||
from .schemas import ConfigSchema, ConfigSchemaNlp, validate_init_settings
|
||||
from .git_info import GIT_VERSION
|
||||
from . import util
|
||||
from . import about
|
||||
|
@ -1162,6 +1162,7 @@ class Language:
|
|||
self,
|
||||
get_examples: Optional[Callable[[], Iterable[Example]]] = None,
|
||||
*,
|
||||
settings: Dict[str, Dict[str, Any]] = SimpleFrozenDict(),
|
||||
sgd: Optional[Optimizer] = None,
|
||||
device: int = -1,
|
||||
) -> Optimizer:
|
||||
|
@ -1207,10 +1208,26 @@ class Language:
|
|||
if sgd is None:
|
||||
sgd = create_default_optimizer()
|
||||
self._optimizer = sgd
|
||||
if hasattr(self.tokenizer, "initialize"):
|
||||
tok_settings = settings.get("tokenizer", {})
|
||||
tok_settings = validate_init_settings(
|
||||
self.tokenizer.initialize,
|
||||
tok_settings,
|
||||
section="tokenizer",
|
||||
name="tokenizer",
|
||||
)
|
||||
self.tokenizer.initialize(get_examples, nlp=self, **tok_settings)
|
||||
for name, proc in self.pipeline:
|
||||
if hasattr(proc, "initialize"):
|
||||
p_settings = settings.get(name, {})
|
||||
p_settings = validate_init_settings(
|
||||
proc.initialize, p_settings, section="components", name=name
|
||||
)
|
||||
proc.initialize(
|
||||
get_examples, pipeline=self.pipeline, sgd=self._optimizer
|
||||
get_examples,
|
||||
pipeline=self.pipeline,
|
||||
sgd=self._optimizer,
|
||||
**p_settings,
|
||||
)
|
||||
self._link_components()
|
||||
return self._optimizer
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# cython: infer_types=True, cdivision=True, boundscheck=False
|
||||
# cython: infer_types=True, cdivision=True, boundscheck=False, binding=True
|
||||
from __future__ import print_function
|
||||
from cymem.cymem cimport Pool
|
||||
cimport numpy as np
|
||||
|
|
|
@ -1,11 +1,13 @@
|
|||
from typing import Dict, List, Union, Optional, Any, Callable, Type, Tuple
|
||||
from typing import Iterable, TypeVar, TYPE_CHECKING
|
||||
from enum import Enum
|
||||
from pydantic import BaseModel, Field, ValidationError, validator
|
||||
from pydantic import BaseModel, Field, ValidationError, validator, create_model
|
||||
from pydantic import StrictStr, StrictInt, StrictFloat, StrictBool
|
||||
from pydantic.main import ModelMetaclass
|
||||
from thinc.api import Optimizer, ConfigValidationError
|
||||
from thinc.config import Promise
|
||||
from collections import defaultdict
|
||||
from thinc.api import Optimizer
|
||||
import inspect
|
||||
|
||||
from .attrs import NAMES
|
||||
from .lookups import Lookups
|
||||
|
@ -43,6 +45,93 @@ def validate(schema: Type[BaseModel], obj: Dict[str, Any]) -> List[str]:
|
|||
return [f"[{loc}] {', '.join(msg)}" for loc, msg in data.items()]
|
||||
|
||||
|
||||
# Initialization
|
||||
|
||||
|
||||
class ArgSchemaConfig:
|
||||
extra = "forbid"
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
class ArgSchemaConfigExtra:
|
||||
extra = "forbid"
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
def get_arg_model(
|
||||
func: Callable,
|
||||
*,
|
||||
exclude: Iterable[str] = tuple(),
|
||||
name: str = "ArgModel",
|
||||
strict: bool = True,
|
||||
) -> ModelMetaclass:
|
||||
"""Generate a pydantic model for function arguments.
|
||||
|
||||
func (Callable): The function to generate the schema for.
|
||||
exclude (Iterable[str]): Parameter names to ignore.
|
||||
name (str): Name of created model class.
|
||||
strict (bool): Don't allow extra arguments if no variable keyword arguments
|
||||
are allowed on the function.
|
||||
RETURNS (ModelMetaclass): A pydantic model.
|
||||
"""
|
||||
sig_args = {}
|
||||
try:
|
||||
sig = inspect.signature(func)
|
||||
except ValueError:
|
||||
# Typically happens if the method is part of a Cython module without
|
||||
# binding=True. Here we just use an empty model that allows everything.
|
||||
return create_model(name, __config__=ArgSchemaConfigExtra)
|
||||
has_variable = False
|
||||
for param in sig.parameters.values():
|
||||
if param.name in exclude:
|
||||
continue
|
||||
if param.kind == param.VAR_KEYWORD:
|
||||
# The function allows variable keyword arguments so we shouldn't
|
||||
# include **kwargs etc. in the schema and switch to non-strict
|
||||
# mode and pass through all other values
|
||||
has_variable = True
|
||||
continue
|
||||
# If no annotation is specified assume it's anything
|
||||
annotation = param.annotation if param.annotation != param.empty else Any
|
||||
# If no default value is specified assume that it's required
|
||||
default = param.default if param.default != param.empty else ...
|
||||
sig_args[param.name] = (annotation, default)
|
||||
is_strict = strict and not has_variable
|
||||
sig_args["__config__"] = ArgSchemaConfig if is_strict else ArgSchemaConfigExtra
|
||||
return create_model(name, **sig_args)
|
||||
|
||||
|
||||
def validate_init_settings(
|
||||
func: Callable,
|
||||
settings: Dict[str, Any],
|
||||
*,
|
||||
section: Optional[str] = None,
|
||||
name: str = "",
|
||||
exclude: Iterable[str] = ("get_examples", "pipeline", "sgd"),
|
||||
) -> Dict[str, Any]:
|
||||
"""Validate initialization settings against the expected arguments in
|
||||
the method signature. Will parse values if possible (e.g. int to string)
|
||||
and return the updated settings dict. Will raise a ConfigValidationError
|
||||
if types don't match or required values are missing.
|
||||
|
||||
func (Callable): The initialize method of a given component etc.
|
||||
settings (Dict[str, Any]): The settings from the repsective [initialize] block.
|
||||
section (str): Initialize section, for error message.
|
||||
name (str): Name of the block in the section.
|
||||
exclude (Iterable[str]): Parameter names to exclude from schema.
|
||||
RETURNS (Dict[str, Any]): The validated settings.
|
||||
"""
|
||||
schema = get_arg_model(func, exclude=exclude, name="InitArgModel")
|
||||
try:
|
||||
return schema(**settings).dict()
|
||||
except ValidationError as e:
|
||||
block = "initialize" if not section else f"initialize.{section}"
|
||||
title = f"Error validating initialization settings in [{block}]"
|
||||
raise ConfigValidationError(
|
||||
title=title, errors=e.errors(), config=settings, parent=name,
|
||||
) from None
|
||||
|
||||
|
||||
# Matcher token patterns
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue