Update Language.initialize and support components/tokenizer settings

This commit is contained in:
Ines Montani 2020-09-29 11:52:45 +02:00
parent 4925ad760a
commit dec984a9c1
3 changed files with 111 additions and 5 deletions

View File

@ -27,7 +27,7 @@ from .lang.punctuation import TOKENIZER_INFIXES
from .tokens import Doc
from .tokenizer import Tokenizer
from .errors import Errors, Warnings
from .schemas import ConfigSchema, ConfigSchemaNlp
from .schemas import ConfigSchema, ConfigSchemaNlp, validate_init_settings
from .git_info import GIT_VERSION
from . import util
from . import about
@ -1162,6 +1162,7 @@ class Language:
self,
get_examples: Optional[Callable[[], Iterable[Example]]] = None,
*,
settings: Dict[str, Dict[str, Any]] = SimpleFrozenDict(),
sgd: Optional[Optimizer] = None,
device: int = -1,
) -> Optimizer:
@ -1207,10 +1208,26 @@ class Language:
if sgd is None:
sgd = create_default_optimizer()
self._optimizer = sgd
if hasattr(self.tokenizer, "initialize"):
tok_settings = settings.get("tokenizer", {})
tok_settings = validate_init_settings(
self.tokenizer.initialize,
tok_settings,
section="tokenizer",
name="tokenizer",
)
self.tokenizer.initialize(get_examples, nlp=self, **tok_settings)
for name, proc in self.pipeline:
if hasattr(proc, "initialize"):
p_settings = settings.get(name, {})
p_settings = validate_init_settings(
proc.initialize, p_settings, section="components", name=name
)
proc.initialize(
get_examples, pipeline=self.pipeline, sgd=self._optimizer
get_examples,
pipeline=self.pipeline,
sgd=self._optimizer,
**p_settings,
)
self._link_components()
return self._optimizer

View File

@ -1,4 +1,4 @@
# cython: infer_types=True, cdivision=True, boundscheck=False
# cython: infer_types=True, cdivision=True, boundscheck=False, binding=True
from __future__ import print_function
from cymem.cymem cimport Pool
cimport numpy as np

View File

@ -1,11 +1,13 @@
from typing import Dict, List, Union, Optional, Any, Callable, Type, Tuple
from typing import Iterable, TypeVar, TYPE_CHECKING
from enum import Enum
from pydantic import BaseModel, Field, ValidationError, validator
from pydantic import BaseModel, Field, ValidationError, validator, create_model
from pydantic import StrictStr, StrictInt, StrictFloat, StrictBool
from pydantic.main import ModelMetaclass
from thinc.api import Optimizer, ConfigValidationError
from thinc.config import Promise
from collections import defaultdict
from thinc.api import Optimizer
import inspect
from .attrs import NAMES
from .lookups import Lookups
@ -43,6 +45,93 @@ def validate(schema: Type[BaseModel], obj: Dict[str, Any]) -> List[str]:
return [f"[{loc}] {', '.join(msg)}" for loc, msg in data.items()]
# Initialization
class ArgSchemaConfig:
extra = "forbid"
arbitrary_types_allowed = True
class ArgSchemaConfigExtra:
extra = "forbid"
arbitrary_types_allowed = True
def get_arg_model(
func: Callable,
*,
exclude: Iterable[str] = tuple(),
name: str = "ArgModel",
strict: bool = True,
) -> ModelMetaclass:
"""Generate a pydantic model for function arguments.
func (Callable): The function to generate the schema for.
exclude (Iterable[str]): Parameter names to ignore.
name (str): Name of created model class.
strict (bool): Don't allow extra arguments if no variable keyword arguments
are allowed on the function.
RETURNS (ModelMetaclass): A pydantic model.
"""
sig_args = {}
try:
sig = inspect.signature(func)
except ValueError:
# Typically happens if the method is part of a Cython module without
# binding=True. Here we just use an empty model that allows everything.
return create_model(name, __config__=ArgSchemaConfigExtra)
has_variable = False
for param in sig.parameters.values():
if param.name in exclude:
continue
if param.kind == param.VAR_KEYWORD:
# The function allows variable keyword arguments so we shouldn't
# include **kwargs etc. in the schema and switch to non-strict
# mode and pass through all other values
has_variable = True
continue
# If no annotation is specified assume it's anything
annotation = param.annotation if param.annotation != param.empty else Any
# If no default value is specified assume that it's required
default = param.default if param.default != param.empty else ...
sig_args[param.name] = (annotation, default)
is_strict = strict and not has_variable
sig_args["__config__"] = ArgSchemaConfig if is_strict else ArgSchemaConfigExtra
return create_model(name, **sig_args)
def validate_init_settings(
func: Callable,
settings: Dict[str, Any],
*,
section: Optional[str] = None,
name: str = "",
exclude: Iterable[str] = ("get_examples", "pipeline", "sgd"),
) -> Dict[str, Any]:
"""Validate initialization settings against the expected arguments in
the method signature. Will parse values if possible (e.g. int to string)
and return the updated settings dict. Will raise a ConfigValidationError
if types don't match or required values are missing.
func (Callable): The initialize method of a given component etc.
settings (Dict[str, Any]): The settings from the repsective [initialize] block.
section (str): Initialize section, for error message.
name (str): Name of the block in the section.
exclude (Iterable[str]): Parameter names to exclude from schema.
RETURNS (Dict[str, Any]): The validated settings.
"""
schema = get_arg_model(func, exclude=exclude, name="InitArgModel")
try:
return schema(**settings).dict()
except ValidationError as e:
block = "initialize" if not section else f"initialize.{section}"
title = f"Error validating initialization settings in [{block}]"
raise ConfigValidationError(
title=title, errors=e.errors(), config=settings, parent=name,
) from None
# Matcher token patterns