Fix GPU usage in Language

This commit is contained in:
Matthew Honnibal 2017-05-18 04:25:19 -05:00
parent 711ad5edc4
commit 2713041571
1 changed files with 26 additions and 1 deletions

View File

@ -3,6 +3,10 @@ from __future__ import absolute_import, unicode_literals
from contextlib import contextmanager from contextlib import contextmanager
import dill import dill
import numpy
from thinc.neural import Model
from thinc.neural.ops import NumpyOps, CupyOps
from .tokenizer import Tokenizer from .tokenizer import Tokenizer
from .vocab import Vocab from .vocab import Vocab
from .tagger import Tagger from .tagger import Tagger
@ -179,11 +183,16 @@ class Language(object):
state = process.update(docs, golds, state = process.update(docs, golds,
state=state, state=state,
drop=drop, drop=drop,
sgd=sgd) sgd=get_grads)
else: else:
process(docs, state=state) process(docs, state=state)
if sgd is not None: if sgd is not None:
for key, (W, dW) in grads.items(): for key, (W, dW) in grads.items():
# TODO: Unhack this when thinc improves
if isinstance(W, numpy.ndarray):
sgd.ops = NumpyOps()
else:
sgd.ops = CupyOps()
sgd(W, dW, key=key) sgd(W, dW, key=key)
return state return state
@ -197,6 +206,10 @@ class Language(object):
# Handle crossing dependencies # Handle crossing dependencies
gold_tuples = PseudoProjectivity.preprocess_training_data(gold_tuples) gold_tuples = PseudoProjectivity.preprocess_training_data(gold_tuples)
contexts = [] contexts = []
if cfg.get('use_gpu'):
Model.ops = CupyOps()
Model.Ops = CupyOps
print("Use GPU")
for proc in self.pipeline: for proc in self.pipeline:
if hasattr(proc, 'begin_training'): if hasattr(proc, 'begin_training'):
context = proc.begin_training(gold_tuples, context = proc.begin_training(gold_tuples,
@ -205,6 +218,18 @@ class Language(object):
trainer = Trainer(self, gold_tuples, **cfg) trainer = Trainer(self, gold_tuples, **cfg)
yield trainer, trainer.optimizer yield trainer, trainer.optimizer
@contextmanager
def use_params(self, params, **cfg):
contexts = [pipe.model.use_params(params) for pipe
in self.pipeline if hasattr(pipe, 'model')
and hasattr(pipe.model, 'use_params')]
yield
for context in contexts:
try:
next(context.gen)
except StopIteration:
pass
def pipe(self, texts, n_threads=2, batch_size=1000, **disabled): def pipe(self, texts, n_threads=2, batch_size=1000, **disabled):
""" """
Process texts as a stream, and yield Doc objects in order. Process texts as a stream, and yield Doc objects in order.