From 5738d373d5b1142cdea1cee4f44d75b454da935a Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Sun, 21 May 2017 18:43:31 -0500 Subject: [PATCH] Add deprojectivize to pipeline --- spacy/language.py | 27 ++++++++++++++++++++++----- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/spacy/language.py b/spacy/language.py index 2f14ea3de..0f38252f7 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -93,10 +93,12 @@ class BaseDefaults(object): factories = { 'make_doc': create_tokenizer, - 'token_vectors': lambda nlp, **cfg: TokenVectorEncoder(nlp.vocab, **cfg), - 'tags': lambda nlp, **cfg: NeuralTagger(nlp.vocab, **cfg), - 'dependencies': lambda nlp, **cfg: NeuralDependencyParser(nlp.vocab, **cfg), - 'entities': lambda nlp, **cfg: NeuralEntityRecognizer(nlp.vocab, **cfg), + 'token_vectors': lambda nlp, **cfg: [TokenVectorEncoder(nlp.vocab, **cfg)], + 'tags': lambda nlp, **cfg: [NeuralTagger(nlp.vocab, **cfg)], + 'dependencies': lambda nlp, **cfg: [ + NeuralDependencyParser(nlp.vocab, **cfg), + PseudoProjectivity.deprojectivize], + 'entities': lambda nlp, **cfg: [NeuralEntityRecognizer(nlp.vocab, **cfg)], } token_match = TOKEN_MATCH @@ -162,6 +164,13 @@ class Language(object): self.pipeline[i] = factory(self, **meta.get(entry, {})) else: self.pipeline = [] + flat_list = [] + for pipe in self.pipeline: + if isinstance(pipe, list): + flat_list.extend(pipe) + else: + flat_list.append(pipe) + self.pipeline = flat_list def __call__(self, text, **disabled): """'Apply the pipeline to some text. The text can span multiple sentences, @@ -207,6 +216,8 @@ class Language(object): tok2vec = self.pipeline[0] feats = tok2vec.doc2feats(docs) for proc in self.pipeline[1:]: + if not hasattr(proc, 'update'): + continue grads = {} tokvecses, bp_tokvecses = tok2vec.model.begin_update(feats, drop=drop) d_tokvecses = proc.update((docs, tokvecses), golds, sgd=get_grads, drop=drop) @@ -326,7 +337,8 @@ class Language(object): if hasattr(proc, 'pipe'): docs = proc.pipe(docs, n_threads=n_threads, batch_size=batch_size) else: - docs = (proc(doc) for doc in docs) + # Apply the function, but yield the doc + docs = _pipe(proc, docs) for doc in docs: yield doc @@ -402,3 +414,8 @@ class Language(object): if key not in exclude: setattr(self, key, value) return self + +def _pipe(func, docs): + for doc in docs: + func(doc) + yield doc