mirror of https://github.com/explosion/spaCy.git
Fix use_params and pipe methods
This commit is contained in:
parent
ca70b08661
commit
c2c825127a
|
@ -220,13 +220,19 @@ class Language(object):
|
|||
|
||||
@contextmanager
|
||||
def use_params(self, params, **cfg):
|
||||
contexts = [pipe.model.use_params(params) for pipe
|
||||
in self.pipeline if hasattr(pipe, 'model')
|
||||
and hasattr(pipe.model, 'use_params')]
|
||||
contexts = [pipe.use_params(params) for pipe
|
||||
in self.pipeline if hasattr(pipe, 'use_params')]
|
||||
# TODO: Having trouble with contextlib
|
||||
# Workaround: these aren't actually context managers atm.
|
||||
for context in contexts:
|
||||
try:
|
||||
next(context)
|
||||
except StopIteration:
|
||||
pass
|
||||
yield
|
||||
for context in contexts:
|
||||
try:
|
||||
next(context.gen)
|
||||
next(context)
|
||||
except StopIteration:
|
||||
pass
|
||||
|
||||
|
@ -242,7 +248,8 @@ class Language(object):
|
|||
parse (bool)
|
||||
entity (bool)
|
||||
"""
|
||||
stream = ((self.make_doc(text), None) for text in texts)
|
||||
#stream = ((self.make_doc(text), None) for text in texts)
|
||||
stream = ((doc, {}) for doc in texts)
|
||||
for proc in self.pipeline:
|
||||
name = getattr(proc, 'name', None)
|
||||
if name in disabled and not disabled[name]:
|
||||
|
|
|
@ -61,8 +61,14 @@ class TokenVectorEncoder(object):
|
|||
state['tokvecs'] = tokvecs
|
||||
return state
|
||||
|
||||
def pipe(self, docs, **kwargs):
|
||||
raise NotImplementedError
|
||||
def pipe(self, stream, batch_size=128, n_threads=-1):
|
||||
for batch in cytoolz.partition_all(batch_size, stream):
|
||||
docs, states = zip(*batch)
|
||||
tokvecs = self.predict(docs)
|
||||
self.set_annotations(docs, tokvecs)
|
||||
for state in states:
|
||||
state['tokvecs'] = tokvecs
|
||||
yield from zip(docs, states)
|
||||
|
||||
def predict(self, docs):
|
||||
feats = self.doc2feats(docs)
|
||||
|
@ -96,6 +102,10 @@ class TokenVectorEncoder(object):
|
|||
if self.model is True:
|
||||
self.model = self.Model()
|
||||
|
||||
def use_params(self, params):
|
||||
with self.model.use_params(params):
|
||||
yield
|
||||
|
||||
|
||||
class NeuralTagger(object):
|
||||
name = 'nn_tagger'
|
||||
|
@ -112,11 +122,13 @@ class NeuralTagger(object):
|
|||
return state
|
||||
|
||||
def pipe(self, stream, batch_size=128, n_threads=-1):
|
||||
for batch in cytoolz.partition_all(batch_size, batch):
|
||||
docs, tokvecs = zip(*batch)
|
||||
tag_ids = self.predict(docs, tokvecs)
|
||||
for batch in cytoolz.partition_all(batch_size, stream):
|
||||
docs, states = zip(*batch)
|
||||
tag_ids = self.predict(states[0]['tokvecs'])
|
||||
self.set_annotations(docs, tag_ids)
|
||||
yield from docs
|
||||
for state in states:
|
||||
state['tag_ids'] = tag_ids
|
||||
yield from zip(docs, states)
|
||||
|
||||
def predict(self, tokvecs):
|
||||
scores = self.model(tokvecs)
|
||||
|
@ -130,7 +142,7 @@ class NeuralTagger(object):
|
|||
docs = [docs]
|
||||
cdef Doc doc
|
||||
cdef int idx = 0
|
||||
cdef int i, j
|
||||
cdef int i, j, tag_id
|
||||
cdef Vocab vocab = self.vocab
|
||||
for i, doc in enumerate(docs):
|
||||
doc_tag_ids = batch_tag_ids[idx:idx+len(doc)]
|
||||
|
@ -147,7 +159,6 @@ class NeuralTagger(object):
|
|||
self.model.nI = tokvecs.shape[1]
|
||||
|
||||
tag_scores, bp_tag_scores = self.model.begin_update(tokvecs, drop=drop)
|
||||
|
||||
loss, d_tag_scores = self.get_loss(docs, golds, tag_scores)
|
||||
|
||||
d_tokvecs = bp_tag_scores(d_tag_scores, sgd=sgd)
|
||||
|
@ -167,24 +178,33 @@ class NeuralTagger(object):
|
|||
for tag in gold.tags:
|
||||
correct[idx] = tag_index[tag]
|
||||
idx += 1
|
||||
correct = self.model.ops.xp.array(correct)
|
||||
correct = self.model.ops.xp.array(correct, dtype='i')
|
||||
d_scores = scores - to_categorical(correct, nb_classes=scores.shape[1])
|
||||
loss = (d_scores**2).sum()
|
||||
d_scores = self.model.ops.asarray(d_scores)
|
||||
return loss, d_scores
|
||||
d_scores = self.model.ops.asarray(d_scores, dtype='f')
|
||||
return float(loss), d_scores
|
||||
|
||||
def begin_training(self, gold_tuples, pipeline=None):
|
||||
tag_map = dict(self.vocab.morphology.tag_map)
|
||||
orig_tag_map = dict(self.vocab.morphology.tag_map)
|
||||
new_tag_map = {}
|
||||
for raw_text, annots_brackets in gold_tuples:
|
||||
for annots, brackets in annots_brackets:
|
||||
ids, words, tags, heads, deps, ents = annots
|
||||
for tag in tags:
|
||||
if tag not in tag_map:
|
||||
tag_map[tag] = {POS: X}
|
||||
if tag in orig_tag_map:
|
||||
new_tag_map[tag] = orig_tag_map[tag]
|
||||
else:
|
||||
new_tag_map[tag] = {POS: X}
|
||||
cdef Vocab vocab = self.vocab
|
||||
vocab.morphology = Morphology(vocab.strings, tag_map,
|
||||
vocab.morphology = Morphology(vocab.strings, new_tag_map,
|
||||
vocab.morphology.lemmatizer)
|
||||
self.model = Softmax(self.vocab.morphology.n_tags)
|
||||
print("Tagging", self.model.nO, "tags")
|
||||
|
||||
def use_params(self, params):
|
||||
with self.model.use_params(params):
|
||||
yield
|
||||
|
||||
|
||||
|
||||
cdef class EntityRecognizer(LinearParser):
|
||||
|
|
|
@ -7,6 +7,7 @@ from __future__ import unicode_literals, print_function
|
|||
|
||||
from collections import Counter
|
||||
import ujson
|
||||
import contextlib
|
||||
|
||||
from libc.math cimport exp
|
||||
cimport cython
|
||||
|
@ -297,18 +298,15 @@ cdef class Parser:
|
|||
The number of threads with which to work on the buffer in parallel.
|
||||
Yields (Doc): Documents, in order.
|
||||
"""
|
||||
cdef StateClass state
|
||||
cdef StateClass parse_state
|
||||
cdef Doc doc
|
||||
queue = []
|
||||
for batch in cytoolz.partition_all(batch_size, stream):
|
||||
docs, tokvecs = zip(*batch)
|
||||
states = self.parse_batch(docs, tokvecs)
|
||||
for doc, state in zip(docs, states):
|
||||
self.moves.finalize_state(state.c)
|
||||
for i in range(doc.length):
|
||||
doc.c[i] = state.c._sent[i]
|
||||
self.moves.finalize_doc(doc)
|
||||
yield doc
|
||||
batch = list(batch)
|
||||
docs, states = zip(*batch)
|
||||
parse_states = self.parse_batch(docs, states[0]['tokvecs'])
|
||||
self.set_annotations(docs, parse_states)
|
||||
yield from zip(docs, states)
|
||||
|
||||
def parse_batch(self, docs, tokvecs):
|
||||
cuda_stream = get_cuda_stream()
|
||||
|
@ -324,7 +322,7 @@ cdef class Parser:
|
|||
scores = vec2scores(vectors)
|
||||
self.transition_batch(states, scores)
|
||||
todo = [st for st in states if not st.is_final()]
|
||||
self.finish_batch(states, docs)
|
||||
return states
|
||||
|
||||
def update(self, docs, golds, state=None, drop=0., sgd=None):
|
||||
assert state is not None
|
||||
|
@ -437,7 +435,7 @@ cdef class Parser:
|
|||
c_d_scores += d_scores.shape[1]
|
||||
return d_scores
|
||||
|
||||
def finish_batch(self, states, docs):
|
||||
def set_annotations(self, docs, states):
|
||||
cdef StateClass state
|
||||
cdef Doc doc
|
||||
for state, doc in zip(states, docs):
|
||||
|
@ -465,6 +463,12 @@ cdef class Parser:
|
|||
if self.model is True:
|
||||
self.model = self.Model(self.moves.n_moves, **cfg)
|
||||
|
||||
def use_params(self, params):
|
||||
# Can't decorate cdef class :(. Workaround.
|
||||
with self.model[0].use_params(params):
|
||||
with self.model[1].use_params(params):
|
||||
yield
|
||||
|
||||
def to_disk(self, path):
|
||||
path = util.ensure_path(path)
|
||||
with (path / 'model.bin').open('wb') as file_:
|
||||
|
|
Loading…
Reference in New Issue