mirror of https://github.com/explosion/spaCy.git
Add deprojectivize to pipeline
This commit is contained in:
parent
1b5fa68996
commit
5738d373d5
|
@ -93,10 +93,12 @@ class BaseDefaults(object):
|
|||
|
||||
factories = {
|
||||
'make_doc': create_tokenizer,
|
||||
'token_vectors': lambda nlp, **cfg: TokenVectorEncoder(nlp.vocab, **cfg),
|
||||
'tags': lambda nlp, **cfg: NeuralTagger(nlp.vocab, **cfg),
|
||||
'dependencies': lambda nlp, **cfg: NeuralDependencyParser(nlp.vocab, **cfg),
|
||||
'entities': lambda nlp, **cfg: NeuralEntityRecognizer(nlp.vocab, **cfg),
|
||||
'token_vectors': lambda nlp, **cfg: [TokenVectorEncoder(nlp.vocab, **cfg)],
|
||||
'tags': lambda nlp, **cfg: [NeuralTagger(nlp.vocab, **cfg)],
|
||||
'dependencies': lambda nlp, **cfg: [
|
||||
NeuralDependencyParser(nlp.vocab, **cfg),
|
||||
PseudoProjectivity.deprojectivize],
|
||||
'entities': lambda nlp, **cfg: [NeuralEntityRecognizer(nlp.vocab, **cfg)],
|
||||
}
|
||||
|
||||
token_match = TOKEN_MATCH
|
||||
|
@ -162,6 +164,13 @@ class Language(object):
|
|||
self.pipeline[i] = factory(self, **meta.get(entry, {}))
|
||||
else:
|
||||
self.pipeline = []
|
||||
flat_list = []
|
||||
for pipe in self.pipeline:
|
||||
if isinstance(pipe, list):
|
||||
flat_list.extend(pipe)
|
||||
else:
|
||||
flat_list.append(pipe)
|
||||
self.pipeline = flat_list
|
||||
|
||||
def __call__(self, text, **disabled):
|
||||
"""'Apply the pipeline to some text. The text can span multiple sentences,
|
||||
|
@ -207,6 +216,8 @@ class Language(object):
|
|||
tok2vec = self.pipeline[0]
|
||||
feats = tok2vec.doc2feats(docs)
|
||||
for proc in self.pipeline[1:]:
|
||||
if not hasattr(proc, 'update'):
|
||||
continue
|
||||
grads = {}
|
||||
tokvecses, bp_tokvecses = tok2vec.model.begin_update(feats, drop=drop)
|
||||
d_tokvecses = proc.update((docs, tokvecses), golds, sgd=get_grads, drop=drop)
|
||||
|
@ -326,7 +337,8 @@ class Language(object):
|
|||
if hasattr(proc, 'pipe'):
|
||||
docs = proc.pipe(docs, n_threads=n_threads, batch_size=batch_size)
|
||||
else:
|
||||
docs = (proc(doc) for doc in docs)
|
||||
# Apply the function, but yield the doc
|
||||
docs = _pipe(proc, docs)
|
||||
for doc in docs:
|
||||
yield doc
|
||||
|
||||
|
@ -402,3 +414,8 @@ class Language(object):
|
|||
if key not in exclude:
|
||||
setattr(self, key, value)
|
||||
return self
|
||||
|
||||
def _pipe(func, docs):
|
||||
for doc in docs:
|
||||
func(doc)
|
||||
yield doc
|
||||
|
|
Loading…
Reference in New Issue