mirror of https://github.com/explosion/spaCy.git
Add Language.disable_pipes()
This commit is contained in:
parent
d9bb1e5de8
commit
e70f80f29e
|
@ -1,6 +1,7 @@
|
|||
# coding: utf8
|
||||
from __future__ import absolute_import, unicode_literals
|
||||
from contextlib import contextmanager
|
||||
import copy
|
||||
|
||||
from thinc.neural import Model
|
||||
from thinc.neural.optimizers import Adam
|
||||
|
@ -329,6 +330,29 @@ class Language(object):
|
|||
doc = proc(doc)
|
||||
return doc
|
||||
|
||||
def disable_pipes(self, *names):
|
||||
'''Disable one or more pipeline components.
|
||||
|
||||
If used as a context manager, the pipeline will be restored to the initial
|
||||
state at the end of the block. Otherwise, a DisabledPipes object is
|
||||
returned, that has a `.restore()` method you can use to undo your
|
||||
changes.
|
||||
|
||||
EXAMPLE:
|
||||
|
||||
>>> nlp.add_pipe('parser')
|
||||
>>> nlp.add_pipe('tagger')
|
||||
>>> with nlp.disable_pipes('parser', 'tagger'):
|
||||
>>> assert not nlp.has_pipe('parser')
|
||||
>>> assert nlp.has_pipe('parser')
|
||||
>>> disabled = nlp.disable_pipes('parser')
|
||||
>>> assert len(disabled) == 1
|
||||
>>> assert not nlp.has_pipe('parser')
|
||||
>>> disabled.restore()
|
||||
>>> assert nlp.has_pipe('parser')
|
||||
'''
|
||||
return DisabledPipes(self, *names)
|
||||
|
||||
def make_doc(self, text):
|
||||
return self.tokenizer(text)
|
||||
|
||||
|
@ -655,6 +679,42 @@ class Language(object):
|
|||
return self
|
||||
|
||||
|
||||
class DisabledPipes(list):
|
||||
'''Manager for temporary pipeline disabling.'''
|
||||
def __init__(self, nlp, *names):
|
||||
self.nlp = nlp
|
||||
self.names = names
|
||||
# Important! Not deep copy -- we just want the container (but we also
|
||||
# want to support people providing arbitrarily typed nlp.pipeline
|
||||
# objects.)
|
||||
self.original_pipeline = copy.copy(nlp.pipeline)
|
||||
list.__init__(self)
|
||||
self.extend(nlp.remove_pipe(name) for name in names)
|
||||
|
||||
def __enter__(self):
|
||||
pass
|
||||
|
||||
def __exit__(self, *args):
|
||||
self.restore()
|
||||
|
||||
def restore(self):
|
||||
'''Restore the pipeline to its state when DisabledPipes was created.'''
|
||||
current, self.nlp.pipeline = self.nlp.pipeline, self.original_pipeline
|
||||
unexpected = [name for name in current if not self.nlp.has_pipe(name)]
|
||||
if unexpected:
|
||||
# Don't change the pipeline if we're raising an error.
|
||||
self.nlp.pipeline = current
|
||||
msg = (
|
||||
"Some current components would be lost when restoring "
|
||||
"previous pipeline state. If you added components after "
|
||||
"calling nlp.disable_pipes(), you should remove them "
|
||||
"explicitly with nlp.remove_pipe() before the pipeline is "
|
||||
"restore. Names of the new components: %s"
|
||||
)
|
||||
raise ValueError(msg % unexpected)
|
||||
self[:] = []
|
||||
|
||||
|
||||
def unpickle_language(vocab, meta, bytes_data):
|
||||
lang = Language(vocab=vocab)
|
||||
lang.from_bytes(bytes_data)
|
||||
|
|
|
@ -82,3 +82,21 @@ def test_remove_pipe(nlp, name):
|
|||
assert not len(nlp.pipeline)
|
||||
assert removed_name == name
|
||||
assert removed_component == new_pipe
|
||||
|
||||
|
||||
@pytest.mark.parametrize('name', ['my_component'])
|
||||
def test_disable_pipes_method(nlp, name):
|
||||
nlp.add_pipe(new_pipe, name=name)
|
||||
assert nlp.has_pipe(name)
|
||||
disabled = nlp.disable_pipes(name)
|
||||
assert not nlp.has_pipe(name)
|
||||
disabled.restore()
|
||||
|
||||
|
||||
@pytest.mark.parametrize('name', ['my_component'])
|
||||
def test_disable_pipes_context(nlp, name):
|
||||
nlp.add_pipe(new_pipe, name=name)
|
||||
assert nlp.has_pipe(name)
|
||||
with nlp.disable_pipes(name):
|
||||
assert not nlp.has_pipe(name)
|
||||
assert nlp.has_pipe(name)
|
||||
|
|
Loading…
Reference in New Issue