mirror of https://github.com/explosion/spaCy.git
Fix is_cython_func for additional imported code
* Fix `is_cython_func` for imported code loaded under `python_code` module name * Add `make_named_tempfile` context manager to test utils to test loading of imported code * Add test for validation of `initialize` params in custom module
This commit is contained in:
parent
10c930cc96
commit
e9f7f9a4bc
|
@ -7,7 +7,7 @@ from spacy import util
|
||||||
from spacy import prefer_gpu, require_gpu, require_cpu
|
from spacy import prefer_gpu, require_gpu, require_cpu
|
||||||
from spacy.ml._precomputable_affine import PrecomputableAffine
|
from spacy.ml._precomputable_affine import PrecomputableAffine
|
||||||
from spacy.ml._precomputable_affine import _backprop_precomputable_affine_padding
|
from spacy.ml._precomputable_affine import _backprop_precomputable_affine_padding
|
||||||
from spacy.util import dot_to_object, SimpleFrozenList
|
from spacy.util import dot_to_object, SimpleFrozenList, import_file
|
||||||
from thinc.api import Config, Optimizer, ConfigValidationError
|
from thinc.api import Config, Optimizer, ConfigValidationError
|
||||||
from spacy.training.batchers import minibatch_by_words
|
from spacy.training.batchers import minibatch_by_words
|
||||||
from spacy.lang.en import English
|
from spacy.lang.en import English
|
||||||
|
@ -17,7 +17,7 @@ from spacy.schemas import ConfigSchemaTraining
|
||||||
|
|
||||||
from thinc.api import get_current_ops, NumpyOps, CupyOps
|
from thinc.api import get_current_ops, NumpyOps, CupyOps
|
||||||
|
|
||||||
from .util import get_random_doc
|
from .util import get_random_doc, make_named_tempfile
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
@ -347,3 +347,34 @@ def test_resolve_dot_names():
|
||||||
errors = e.value.errors
|
errors = e.value.errors
|
||||||
assert len(errors) == 1
|
assert len(errors) == 1
|
||||||
assert errors[0]["loc"] == ["training", "xyz"]
|
assert errors[0]["loc"] == ["training", "xyz"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_import_code():
|
||||||
|
code_str = """
|
||||||
|
from spacy import Language
|
||||||
|
|
||||||
|
class DummyComponent:
|
||||||
|
def __init__(self, vocab, name):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def initialize(self, get_examples, *, nlp, dummy_param: int):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@Language.factory(
|
||||||
|
"dummy_component",
|
||||||
|
)
|
||||||
|
def make_dummy_component(
|
||||||
|
nlp: Language, name: str
|
||||||
|
):
|
||||||
|
return DummyComponent(nlp.vocab, name)
|
||||||
|
"""
|
||||||
|
|
||||||
|
with make_named_tempfile(mode="w", suffix=".py") as fileh:
|
||||||
|
fileh.write(code_str)
|
||||||
|
fileh.flush()
|
||||||
|
|
||||||
|
import_file("python_code", fileh.name)
|
||||||
|
config = {"initialize": {"components": {"dummy_component": {"dummy_param": 1}}}}
|
||||||
|
nlp = English.from_config(config)
|
||||||
|
nlp.add_pipe("dummy_component")
|
||||||
|
nlp.initialize()
|
||||||
|
|
|
@ -14,6 +14,13 @@ def make_tempfile(mode="r"):
|
||||||
f.close()
|
f.close()
|
||||||
|
|
||||||
|
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def make_named_tempfile(mode="r", suffix=None):
|
||||||
|
f = tempfile.NamedTemporaryFile(mode=mode, suffix=suffix)
|
||||||
|
yield f
|
||||||
|
f.close()
|
||||||
|
|
||||||
|
|
||||||
def get_batch(batch_size):
|
def get_batch(batch_size):
|
||||||
vocab = Vocab()
|
vocab = Vocab()
|
||||||
docs = []
|
docs = []
|
||||||
|
|
|
@ -1454,7 +1454,8 @@ def is_cython_func(func: Callable) -> bool:
|
||||||
if hasattr(func, attr): # function or class instance
|
if hasattr(func, attr): # function or class instance
|
||||||
return True
|
return True
|
||||||
# https://stackoverflow.com/a/55767059
|
# https://stackoverflow.com/a/55767059
|
||||||
if hasattr(func, "__qualname__") and hasattr(func, "__module__"): # method
|
if hasattr(func, "__qualname__") and hasattr(func, "__module__") \
|
||||||
|
and func.__module__ in sys.modules: # method
|
||||||
cls_func = vars(sys.modules[func.__module__])[func.__qualname__.split(".")[0]]
|
cls_func = vars(sys.modules[func.__module__])[func.__qualname__.split(".")[0]]
|
||||||
return hasattr(cls_func, attr)
|
return hasattr(cls_func, attr)
|
||||||
return False
|
return False
|
||||||
|
|
Loading…
Reference in New Issue