Break the tokenization stage out of the pipeline into a function 'make_doc'. This allows all pipeline methods to have the same signature.

Matthew Honnibal 2016-10-14 17:38:29 +02:00
parent 2cc515b2ed
commit 6d8cb515ac
1 changed file with 15 additions and 9 deletions
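In short: before this change, nlp.pipeline[0] was the tokenizer, the only stage that maps raw text to a Doc, so __call__ and pipe() had to special-case index 0 and slice the loop with pipeline[1:]. Moving tokenization into make_doc leaves every pipeline member with the same Doc-in signature. A minimal sketch of the calling pattern (hypothetical stand-in names, not spaCy code):

    def make_doc(text):
        # Stands in for nlp.tokenizer.__call__: raw text -> "Doc".
        return text.split()

    def tagger(doc):
        # Every remaining pipeline stage shares this doc-in signature.
        doc.append('TAGGED')

    pipeline = [tagger]           # no tokenizer at index 0 any more

    def call(text):
        doc = make_doc(text)      # tokenization happens before the loop
        for proc in pipeline:     # so no pipeline[1:] slice is needed
            proc(doc)
        return doc

    print(call('An example'))    # ['An', 'example', 'TAGGED']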

@@ -35,7 +35,6 @@ from .syntax.parser import get_templates
 from .syntax.nonproj import PseudoProjectivity
 
 
-
 class BaseDefaults(object):
     def __init__(self, lang, path):
         self.path = path
@@ -125,8 +124,11 @@ class BaseDefaults(object):
         else:
             return Matcher(vocab)
 
+    def MakeDoc(self, nlp, **cfg):
+        return nlp.tokenizer.__call__
+
     def Pipeline(self, nlp, **cfg):
-        pipeline = [nlp.tokenizer]
+        pipeline = []
         if nlp.tagger:
             pipeline.append(nlp.tagger)
         if nlp.parser:
@@ -265,6 +267,7 @@ class Language(object):
                  matcher=True,
                  serializer=True,
                  vectors=True,
+                 make_doc=True,
                  pipeline=True,
                  defaults=True,
                  data_dir=None):
@@ -303,6 +306,11 @@ class Language(object):
         self.entity = entity if entity is not True else defaults.Entity(self.vocab)
         self.parser = parser if parser is not True else defaults.Parser(self.vocab)
         self.matcher = matcher if matcher is not True else defaults.Matcher(self.vocab)
+        if make_doc in (None, True, False):
+            self.make_doc = defaults.MakeDoc(self)
+        else:
+            self.make_doc = make_doc
+
         if pipeline in (None, False):
             self.pipeline = []
         elif pipeline is True:
@@ -339,24 +347,22 @@ class Language(object):
         >>> tokens[0].orth_, tokens[0].head.tag_
         ('An', 'NN')
         """
-        doc = self.pipeline[0](text)
+        doc = self.make_doc(text)
         if self.entity and entity:
             # Add any of the entity labels already set, in case we don't have them.
             for token in doc:
                 if token.ent_type != 0:
                     self.entity.add_label(token.ent_type)
         skip = {self.tagger: not tag, self.parser: not parse, self.entity: not entity}
-        for proc in self.pipeline[1:]:
+        for proc in self.pipeline:
             if proc and not skip.get(proc):
                 proc(doc)
         return doc
 
-    def pipe(self, texts, tag=True, parse=True, entity=True, n_threads=2,
-             batch_size=1000):
+    def pipe(self, texts, tag=True, parse=True, entity=True, n_threads=2, batch_size=1000):
         skip = {self.tagger: not tag, self.parser: not parse, self.entity: not entity}
-        stream = self.pipeline[0].pipe(texts,
-            n_threads=n_threads, batch_size=batch_size)
-        for proc in self.pipeline[1:]:
+        stream = (self.make_doc(text) for text in texts)
+        for proc in self.pipeline:
             if proc and not skip.get(proc):
                 if hasattr(proc, 'pipe'):
                     stream = proc.pipe(stream, n_threads=n_threads, batch_size=batch_size)
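
The new make_doc hook also makes it possible to swap in custom tokenization without touching the pipeline, since __call__ and pipe() now route through self.make_doc instead of self.pipeline[0]. A usage sketch against the 1.x-era API shown in this diff; the whitespace Doc construction (spacy.tokens.Doc(vocab, words=...)) and the import path are assumptions, not part of this commit:

    from spacy.en import English      # 1.x-era import path; an assumption
    from spacy.tokens import Doc

    nlp = English()

    def whitespace_make_doc(text):
        # Hypothetical replacement for the default MakeDoc result
        # (nlp.tokenizer.__call__): any text -> Doc callable will do.
        return Doc(nlp.vocab, words=text.split())

    nlp.make_doc = whitespace_make_doc
    doc = nlp(u'Hello brave new world')   # tagger/parser/entity still apply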