From 4d7f5468bb2033414a039fb2b1ad265ddd91e80a Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Tue, 17 May 2016 16:55:42 +0200 Subject: [PATCH] * Change Language class to use a .pipeline attribute, instead of having the pipeline hard coded --- spacy/language.py | 49 +++++++++++++++++++++++------------------------ 1 file changed, 24 insertions(+), 25 deletions(-) diff --git a/spacy/language.py b/spacy/language.py index 236a0db03..f88f8e4c7 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -192,6 +192,13 @@ class Language(object): if matcher in (None, True): matcher = Matcher.from_package(package, self.vocab) self.matcher = matcher + self.pipeline = [ + self.tokenizer, + self.tagger, + self.entity, + self.parser, + self.matcher + ] def __reduce__(self): args = ( @@ -222,37 +229,29 @@ class Language(object): >>> tokens[0].orth_, tokens[0].head.tag_ ('An', 'NN') """ - tokens = self.tokenizer(text) - if self.tagger and tag: - self.tagger(tokens) - if self.matcher and entity: - self.matcher(tokens) - if self.parser and parse: - self.parser(tokens) + doc = self.pipeline[0](text) if self.entity and entity: # Add any of the entity labels already set, in case we don't have them. - for tok in tokens: - if tok.ent_type != 0: - self.entity.add_label(tok.ent_type) - self.entity(tokens) - return tokens + for token in doc: + if token.ent_type != 0: + self.entity.add_label(token.ent_type) + skip = {self.tagger: not tag, self.parser: not parse, self.entity: not entity} + for proc in self.pipeline[1:]: + if proc and not skip.get(proc): + proc(doc) + return doc def pipe(self, texts, tag=True, parse=True, entity=True, n_threads=2, batch_size=1000): - stream = self.tokenizer.pipe(texts, + skip = {self.tagger: not tag, self.parser: not parse, self.entity: not entity} + stream = self.pipeline[0].pipe(texts, n_threads=n_threads, batch_size=batch_size) - if self.tagger and tag: - stream = self.tagger.pipe(stream, - n_threads=n_threads, batch_size=batch_size) - if self.matcher and entity: - stream = self.matcher.pipe(stream, - n_threads=n_threads, batch_size=batch_size) - if self.parser and parse: - stream = self.parser.pipe(stream, - n_threads=n_threads, batch_size=batch_size) - if self.entity and entity: - stream = self.entity.pipe(stream, - n_threads=1, batch_size=batch_size) + for proc in self.pipeline[1:]: + if proc and not skip.get(proc): + if hasattr(proc, 'pipe'): + stream = proc.pipe(stream, n_threads=n_threads, batch_size=batch_size) + else: + stream = (proc(item) for item in stream) for doc in stream: yield doc