diff --git a/spacy/language.py b/spacy/language.py index 228225404..1e4ae1474 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -145,7 +145,7 @@ class Language(object): else: self.pipeline = [] - def __call__(self, text, state=None, **disabled): + def __call__(self, text, **disabled): """ Apply the pipeline to some text. The text can span multiple sentences, and can contain arbtrary whitespace. Alignment into the original string @@ -153,7 +153,6 @@ class Language(object): Args: text (unicode): The text to be processed. - state: Arbitrary Returns: doc (Doc): A container for accessing the annotations. @@ -170,31 +169,28 @@ class Language(object): name = getattr(proc, 'name', None) if name in disabled and not disabled[name]: continue - state = proc(doc, state=state) + proc(doc) return doc - def update(self, docs, golds, state=None, drop=0., sgd=None): + def update(self, docs, golds, drop=0., sgd=None): grads = {} def get_grads(W, dW, key=None): grads[key] = (W, dW) - state = {} if state is None else state - for process in self.pipeline: - if hasattr(process, 'update'): - state = process.update(docs, golds, - state=state, - drop=drop, - sgd=get_grads) - else: - process(docs, state=state) - if sgd is not None: - for key, (W, dW) in grads.items(): - # TODO: Unhack this when thinc improves - if isinstance(W, numpy.ndarray): - sgd.ops = NumpyOps() - else: - sgd.ops = CupyOps() - sgd(W, dW, key=key) - return state + tok2vec = self.pipeline[0] + feats = tok2vec.doc2feats(docs) + for proc in self.pipeline[1:]: + tokvecs, bp_tokvecs = tok2vec.model.begin_update(feats, drop=drop) + grads = {} + d_tokvecs = proc.update((docs, tokvecs), golds, sgd=get_grads, drop=drop) + bp_tokvecs(d_tokvecs, sgd=get_grads) + if sgd is not None: + for key, (W, dW) in grads.items(): + # TODO: Unhack this when thinc improves + if isinstance(W, numpy.ndarray): + sgd.ops = NumpyOps() + else: + sgd.ops = CupyOps() + sgd(W, dW, key=key) @contextmanager def begin_training(self, gold_tuples, **cfg): @@ -248,18 +244,18 @@ class Language(object): parse (bool) entity (bool) """ - #stream = ((self.make_doc(text), None) for text in texts) - stream = ((doc, {}) for doc in texts) + #docs = (self.make_doc(text) for text in texts) + docs = texts for proc in self.pipeline: name = getattr(proc, 'name', None) if name in disabled and not disabled[name]: continue if hasattr(proc, 'pipe'): - stream = proc.pipe(stream, n_threads=n_threads, batch_size=batch_size) + docs = proc.pipe(docs, n_threads=n_threads, batch_size=batch_size) else: - stream = (proc(doc, state) for doc, state in stream) - for doc, state in stream: + docs = (proc(doc) for doc in docs) + for doc in docs: yield doc def to_disk(self, path, **exclude):