From 2479cde44611fadc1a8b0497fc32f791def4fb3b Mon Sep 17 00:00:00 2001
From: Matthew Honnibal
Date: Mon, 5 Jun 2017 13:13:07 +0200
Subject: [PATCH] Support disable keyword in Language.__init__

---
 spacy/language.py | 19 ++++++++++++++-----
 1 file changed, 14 insertions(+), 5 deletions(-)

diff --git a/spacy/language.py b/spacy/language.py
index eefe3b9d4..106076d25 100644
--- a/spacy/language.py
+++ b/spacy/language.py
@@ -85,11 +85,13 @@ class BaseDefaults(object):
         return NeuralEntityRecognizer(nlp.vocab, **cfg)
 
     @classmethod
-    def create_pipeline(cls, nlp=None):
+    def create_pipeline(cls, nlp=None, disable=tuple()):
         meta = nlp.meta if nlp is not None else {}
         # Resolve strings, like "cnn", "lstm", etc
         pipeline = []
         for entry in cls.pipeline:
+            if entry in disable or getattr(entry, 'name', entry) in disable:
+                continue
             factory = cls.Defaults.factories[entry]
             pipeline.append(factory(nlp, **meta.get(entry, {})))
         return pipeline
@@ -141,7 +143,8 @@ class Language(object):
     Defaults = BaseDefaults
     lang = None
 
-    def __init__(self, vocab=True, make_doc=True, pipeline=None, meta={}, **kwargs):
+    def __init__(self, vocab=True, make_doc=True, pipeline=None, meta={},
+                 disable=tuple(), **kwargs):
         """Initialise a Language object.
 
         vocab (Vocab): A `Vocab` object. If `True`, a vocab is created via
@@ -151,12 +154,14 @@ class Language(object):
         pipeline (list): A list of annotation processes or IDs of annotation,
             processes, e.g. a `Tagger` object, or `'tagger'`. IDs are looked
             up in `Language.Defaults.factories`.
+        disable (list): A list of component names to exclude from the pipeline.
+            The disable list has priority over the pipeline list -- if the same
+            string occurs in both, the component is not loaded.
         meta (dict): Custom meta data for the Language class. Is written to by
             models to add model meta data.
         RETURNS (Language): The newly constructed object.
         """
         self.meta = dict(meta)
-
         if vocab is True:
             factory = self.Defaults.create_vocab
             vocab = factory(self, **meta.get('vocab', {}))
@@ -166,9 +171,13 @@ class Language(object):
             make_doc = factory(self, **meta.get('tokenizer', {}))
         self.tokenizer = make_doc
         if pipeline is True:
-            self.pipeline = self.Defaults.create_pipeline(self)
+            self.pipeline = self.Defaults.create_pipeline(self, disable)
         elif pipeline:
-            self.pipeline = list(pipeline)
+            # Careful not to do getattr(p, 'name', None) here
+            # If we had disable=[None], we'd disable everything!
+            self.pipeline = [p for p in pipeline
+                             if p not in disable
+                             and getattr(p, 'name', p) not in disable]
         # Resolve strings, like "cnn", "lstm", etc
         for i, entry in enumerate(self.pipeline):
             if entry in self.Defaults.factories:
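
Usage note (illustrative, not part of the patch): the snippet below is a minimal,
self-contained sketch of the filtering rule this diff adds. The Tagger class here
is a hypothetical stand-in with a `name` attribute, not spaCy's real component.
A pipeline entry is dropped when the entry itself, or its `name` attribute if it
has one, appears in `disable`.

    # Hypothetical stand-in for a pipeline component with a `name` attribute;
    # not spaCy's Tagger.
    class Tagger(object):
        name = 'tagger'

    pipeline = ['parser', Tagger()]
    disable = ['tagger']

    # Same rule as the list comprehension added in Language.__init__:
    kept = [p for p in pipeline
            if p not in disable
            and getattr(p, 'name', p) not in disable]
    # kept == ['parser']: the Tagger instance is removed because its .name matches.

    # The getattr default is the entry itself rather than None. With a default of
    # None, passing disable=[None] would silently drop every entry that lacks a
    # `name` attribute -- the pitfall the in-diff comment warns about.

At the API level the keyword would be passed through the constructor, e.g.
Language(pipeline=True, disable=['tagger']), assuming 'tagger' is one of the
default factory names at this revision; with pipeline=True the disable list is
forwarded to Defaults.create_pipeline(), which skips the matching entries.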