Fix use_params and pipe methods

Matthew Honnibal 2017-05-18 08:30:59 -05:00
parent ca70b08661
commit c2c825127a
3 changed files with 62 additions and 31 deletions

spacy/language.py

```diff
@@ -220,13 +220,19 @@ class Language(object):
     @contextmanager
     def use_params(self, params, **cfg):
-        contexts = [pipe.model.use_params(params) for pipe
-                    in self.pipeline if hasattr(pipe, 'model')
-                    and hasattr(pipe.model, 'use_params')]
+        contexts = [pipe.use_params(params) for pipe
+                    in self.pipeline if hasattr(pipe, 'use_params')]
+        # TODO: Having trouble with contextlib
+        # Workaround: these aren't actually context managers atm.
+        for context in contexts:
+            try:
+                next(context)
+            except StopIteration:
+                pass
         yield
         for context in contexts:
             try:
-                next(context.gen)
+                next(context)
             except StopIteration:
                 pass
```
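The component-level `use_params` methods this commit adds (in the `pipeline.pyx` and `nn_parser.pyx` hunks below) are plain generator functions rather than decorated context managers, which is why `Language.use_params` drives them by hand: the first `next()` advances each generator to its `yield`, entering the model's real parameter-swapping context, and the second `next()` after the block resumes it so that context unwinds. A minimal, self-contained sketch of the pattern; every class here is a hypothetical stand-in, not spaCy's API:

```python
from contextlib import contextmanager

class Model(object):
    """Hypothetical stand-in for a model with swappable weights."""
    def __init__(self, weights):
        self.weights = weights

    @contextmanager
    def use_params(self, params):
        old, self.weights = self.weights, params
        yield
        self.weights = old

class Pipe(object):
    """Hypothetical component following the commit's convention:
    use_params is a plain generator, not a decorated context manager."""
    def __init__(self, weights):
        self.model = Model(weights)

    def use_params(self, params):
        with self.model.use_params(params):
            yield  # control returns to the caller here

@contextmanager
def use_params(pipeline, params):
    # Mirrors Language.use_params: advance each generator to its yield,
    # which enters the inner `with` and swaps the weights in.
    contexts = [pipe.use_params(params) for pipe in pipeline
                if hasattr(pipe, 'use_params')]
    for context in contexts:
        try:
            next(context)
        except StopIteration:
            pass
    yield
    # Resume each generator past its yield: the generator finishes,
    # the inner `with` exits, and the old weights are restored.
    for context in contexts:
        try:
            next(context)
        except StopIteration:
            pass

pipeline = [Pipe('w0'), Pipe('w1')]
with use_params(pipeline, 'averaged'):
    assert [p.model.weights for p in pipeline] == ['averaged', 'averaged']
assert [p.model.weights for p in pipeline] == ['w0', 'w1']
```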
```diff
@@ -242,7 +248,8 @@ class Language(object):
             parse (bool)
             entity (bool)
         """
-        stream = ((self.make_doc(text), None) for text in texts)
+        #stream = ((self.make_doc(text), None) for text in texts)
+        stream = ((doc, {}) for doc in texts)
         for proc in self.pipeline:
             name = getattr(proc, 'name', None)
             if name in disabled and not disabled[name]:
```
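With this hunk, `Language.pipe` assumes the caller passes `Doc` objects rather than raw strings (note the commented-out `make_doc` line) and pairs each one with a fresh state dict, so the rest of the pipeline sees a stream of `(doc, state)` tuples. A runnable toy of that contract; the component functions are made up for illustration:

```python
# Toy version of the (doc, state) streaming contract. All names are
# stand-ins for illustration, not spaCy's API.
def tokvec_pipe(stream):
    for doc, state in stream:
        state['tokvecs'] = [len(w) for w in doc.split()]  # fake vectors
        yield doc, state

def tagger_pipe(stream):
    for doc, state in stream:
        # a later pipe reads what an earlier pipe stored in the state
        state['tag_ids'] = [v % 2 for v in state['tokvecs']]
        yield doc, state

texts = ['a short doc', 'another doc']
stream = ((doc, {}) for doc in texts)     # mirrors the new stream line above
for proc in (tokvec_pipe, tagger_pipe):
    stream = proc(stream)
for doc, state in stream:
    print(doc, state)
```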

spacy/pipeline.pyx

```diff
@@ -61,8 +61,14 @@ class TokenVectorEncoder(object):
             state['tokvecs'] = tokvecs
         return state

-    def pipe(self, docs, **kwargs):
-        raise NotImplementedError
+    def pipe(self, stream, batch_size=128, n_threads=-1):
+        for batch in cytoolz.partition_all(batch_size, stream):
+            docs, states = zip(*batch)
+            tokvecs = self.predict(docs)
+            self.set_annotations(docs, tokvecs)
+            for state in states:
+                state['tokvecs'] = tokvecs
+            yield from zip(docs, states)

     def predict(self, docs):
         feats = self.doc2feats(docs)
@@ -96,6 +102,10 @@ class TokenVectorEncoder(object):
         if self.model is True:
             self.model = self.Model()

+    def use_params(self, params):
+        with self.model.use_params(params):
+            yield
+

 class NeuralTagger(object):
     name = 'nn_tagger'
```
```diff
@@ -112,11 +122,13 @@ class NeuralTagger(object):
         return state

     def pipe(self, stream, batch_size=128, n_threads=-1):
-        for batch in cytoolz.partition_all(batch_size, batch):
-            docs, tokvecs = zip(*batch)
-            tag_ids = self.predict(docs, tokvecs)
+        for batch in cytoolz.partition_all(batch_size, stream):
+            docs, states = zip(*batch)
+            tag_ids = self.predict(states[0]['tokvecs'])
             self.set_annotations(docs, tag_ids)
-            yield from docs
+            for state in states:
+                state['tag_ids'] = tag_ids
+            yield from zip(docs, states)

     def predict(self, tokvecs):
         scores = self.model(tokvecs)
```
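`TokenVectorEncoder.pipe` and `NeuralTagger.pipe` now share one template: partition the incoming `(doc, state)` stream into batches with `cytoolz.partition_all`, run one prediction per batch, write the annotations, stash the batch result in each state for later pipes (as the real code does with `tokvecs` and `tag_ids`), and re-yield the pairs in order. A generic sketch of that template; `predict` and `set_annotations` are hypothetical stand-ins:

```python
import cytoolz

def predict(docs):
    # hypothetical batch prediction: one output per doc
    return [len(doc) for doc in docs]

def set_annotations(docs, outputs):
    # hypothetical annotation step; real pipes mutate Doc objects here
    for doc, output in zip(docs, outputs):
        pass

def pipe(stream, batch_size=128):
    for batch in cytoolz.partition_all(batch_size, stream):
        docs, states = zip(*batch)        # unzip the (doc, state) pairs
        outputs = predict(docs)           # one forward pass per batch
        set_annotations(docs, outputs)
        for state in states:
            state['outputs'] = outputs    # share results with later pipes
        yield from zip(docs, states)      # restore the pair stream

pairs = [('doc one', {}), ('doc two', {}), ('doc three', {})]
for doc, state in pipe(iter(pairs), batch_size=2):
    print(doc, state['outputs'])
```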
```diff
@@ -130,7 +142,7 @@ class NeuralTagger(object):
             docs = [docs]
         cdef Doc doc
         cdef int idx = 0
-        cdef int i, j
+        cdef int i, j, tag_id
         cdef Vocab vocab = self.vocab
         for i, doc in enumerate(docs):
             doc_tag_ids = batch_tag_ids[idx:idx+len(doc)]
```
```diff
@@ -147,7 +159,6 @@ class NeuralTagger(object):
             self.model.nI = tokvecs.shape[1]
         tag_scores, bp_tag_scores = self.model.begin_update(tokvecs, drop=drop)
         loss, d_tag_scores = self.get_loss(docs, golds, tag_scores)
-
         d_tokvecs = bp_tag_scores(d_tag_scores, sgd=sgd)
```
```diff
@@ -167,24 +178,33 @@ class NeuralTagger(object):
                 for tag in gold.tags:
                     correct[idx] = tag_index[tag]
                     idx += 1
-        correct = self.model.ops.xp.array(correct)
+        correct = self.model.ops.xp.array(correct, dtype='i')
         d_scores = scores - to_categorical(correct, nb_classes=scores.shape[1])
         loss = (d_scores**2).sum()
-        d_scores = self.model.ops.asarray(d_scores)
-        return loss, d_scores
+        d_scores = self.model.ops.asarray(d_scores, dtype='f')
+        return float(loss), d_scores

     def begin_training(self, gold_tuples, pipeline=None):
-        tag_map = dict(self.vocab.morphology.tag_map)
+        orig_tag_map = dict(self.vocab.morphology.tag_map)
+        new_tag_map = {}
         for raw_text, annots_brackets in gold_tuples:
             for annots, brackets in annots_brackets:
                 ids, words, tags, heads, deps, ents = annots
                 for tag in tags:
-                    if tag not in tag_map:
-                        tag_map[tag] = {POS: X}
+                    if tag in orig_tag_map:
+                        new_tag_map[tag] = orig_tag_map[tag]
+                    else:
+                        new_tag_map[tag] = {POS: X}
         cdef Vocab vocab = self.vocab
-        vocab.morphology = Morphology(vocab.strings, tag_map,
+        vocab.morphology = Morphology(vocab.strings, new_tag_map,
                                       vocab.morphology.lemmatizer)
         self.model = Softmax(self.vocab.morphology.n_tags)
+        print("Tagging", self.model.nO, "tags")
+
+    def use_params(self, params):
+        with self.model.use_params(params):
+            yield
+

 cdef class EntityRecognizer(LinearParser):
```
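The `get_loss` edits above pin down dtypes and return a plain Python float; the gradient itself is unchanged and worth spelling out. The target is a one-hot matrix over the tag classes, so `d_scores = scores - to_categorical(correct, ...)` is the gradient of the squared-error loss, up to the constant factor 2. A small numpy check, with a hand-rolled one-hot standing in for the `to_categorical` helper:

```python
import numpy

def to_categorical(correct, nb_classes):
    # one-hot encode integer labels; stand-in for the helper used above
    return numpy.eye(nb_classes, dtype='f')[correct]

scores = numpy.asarray([[0.7, 0.2, 0.1],
                        [0.1, 0.8, 0.1]], dtype='f')
correct = numpy.asarray([0, 2], dtype='i')      # gold tag ids

d_scores = scores - to_categorical(correct, nb_classes=scores.shape[1])
loss = float((d_scores ** 2).sum())

print(d_scores)  # [[-0.3  0.2  0.1], [ 0.1  0.8 -0.9]]
print(loss)      # ~1.6 = (0.09+0.04+0.01) + (0.01+0.64+0.81)
```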

spacy/syntax/nn_parser.pyx

```diff
@@ -7,6 +7,7 @@ from __future__ import unicode_literals, print_function
 from collections import Counter
 import ujson
+import contextlib

 from libc.math cimport exp
 cimport cython
@@ -297,18 +298,15 @@ cdef class Parser:
             The number of threads with which to work on the buffer in parallel.
         Yields (Doc): Documents, in order.
         """
-        cdef StateClass state
+        cdef StateClass parse_state
         cdef Doc doc
         queue = []
         for batch in cytoolz.partition_all(batch_size, stream):
-            docs, tokvecs = zip(*batch)
-            states = self.parse_batch(docs, tokvecs)
-            for doc, state in zip(docs, states):
-                self.moves.finalize_state(state.c)
-                for i in range(doc.length):
-                    doc.c[i] = state.c._sent[i]
-                self.moves.finalize_doc(doc)
-                yield doc
+            batch = list(batch)
+            docs, states = zip(*batch)
+            parse_states = self.parse_batch(docs, states[0]['tokvecs'])
+            self.set_annotations(docs, parse_states)
+            yield from zip(docs, states)

     def parse_batch(self, docs, tokvecs):
         cuda_stream = get_cuda_stream()
```
```diff
@@ -324,7 +322,7 @@ cdef class Parser:
             scores = vec2scores(vectors)
             self.transition_batch(states, scores)
             todo = [st for st in states if not st.is_final()]
-        self.finish_batch(states, docs)
+        return states

     def update(self, docs, golds, state=None, drop=0., sgd=None):
         assert state is not None
```
```diff
@@ -437,7 +435,7 @@ cdef class Parser:
             c_d_scores += d_scores.shape[1]
         return d_scores

-    def finish_batch(self, states, docs):
+    def set_annotations(self, docs, states):
         cdef StateClass state
         cdef Doc doc
         for state, doc in zip(states, docs):
```
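Renaming `finish_batch` to `set_annotations` (with `docs` now first) brings the parser in line with the convention the encoder and tagger follow above: a pure prediction step that returns states without touching the docs, and one mutation point that writes them on. That split is what lets `pipe` batch freely while `__call__` can push a single doc through the same two calls. A schematic with hypothetical names:

```python
class Component(object):
    """Hypothetical component showing the predict/set_annotations split."""
    def predict(self, docs):
        # pure: compute per-doc results without mutating the docs
        return [len(doc) for doc in docs]

    def set_annotations(self, docs, states):
        # the single place where docs get mutated
        for doc, state in zip(docs, states):
            pass

    def __call__(self, doc):
        states = self.predict([doc])
        self.set_annotations([doc], states)
        return doc
```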
```diff
@@ -465,6 +463,12 @@ cdef class Parser:
         if self.model is True:
             self.model = self.Model(self.moves.n_moves, **cfg)

+    def use_params(self, params):
+        # Can't decorate cdef class :(. Workaround.
+        with self.model[0].use_params(params):
+            with self.model[1].use_params(params):
+                yield
+
     def to_disk(self, path):
         path = util.ensure_path(path)
         with (path / 'model.bin').open('wb') as file_:
```
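`Parser.use_params` nests one `with` per sub-model because `self.model` is a pair. For a variable number of sub-models, the same effect can be had with `contextlib.ExitStack`, though that class is Python 3 only, which may be why this code (which still had to run on Python 2) nests explicitly. A sketch under those assumptions; `SubModel` is a made-up stand-in:

```python
import contextlib

class SubModel(object):
    """Hypothetical sub-model with a context-managed use_params."""
    def __init__(self):
        self.params = 'original'

    @contextlib.contextmanager
    def use_params(self, params):
        old, self.params = self.params, params
        yield
        self.params = old

def use_params(models, params):
    # Same shape as Parser.use_params, but for any number of sub-models:
    # enter every sub-model's context, yield once, then unwind together.
    with contextlib.ExitStack() as stack:
        for model in models:
            stack.enter_context(model.use_params(params))
        yield

models = (SubModel(), SubModel())
context = use_params(models, 'averaged')
next(context)                    # primed, as Language.use_params does
assert all(m.params == 'averaged' for m in models)
try:
    next(context)                # resumed after the block; contexts unwind
except StopIteration:
    pass
assert all(m.params == 'original' for m in models)
```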