mirror of https://github.com/explosion/spaCy.git
Fix GPU usage in Language
This commit is contained in:
parent
711ad5edc4
commit
2713041571
|
@ -3,6 +3,10 @@ from __future__ import absolute_import, unicode_literals
|
|||
from contextlib import contextmanager
|
||||
import dill
|
||||
|
||||
import numpy
|
||||
from thinc.neural import Model
|
||||
from thinc.neural.ops import NumpyOps, CupyOps
|
||||
|
||||
from .tokenizer import Tokenizer
|
||||
from .vocab import Vocab
|
||||
from .tagger import Tagger
|
||||
|
@ -179,11 +183,16 @@ class Language(object):
|
|||
state = process.update(docs, golds,
|
||||
state=state,
|
||||
drop=drop,
|
||||
sgd=sgd)
|
||||
sgd=get_grads)
|
||||
else:
|
||||
process(docs, state=state)
|
||||
if sgd is not None:
|
||||
for key, (W, dW) in grads.items():
|
||||
# TODO: Unhack this when thinc improves
|
||||
if isinstance(W, numpy.ndarray):
|
||||
sgd.ops = NumpyOps()
|
||||
else:
|
||||
sgd.ops = CupyOps()
|
||||
sgd(W, dW, key=key)
|
||||
return state
|
||||
|
||||
|
@ -197,6 +206,10 @@ class Language(object):
|
|||
# Handle crossing dependencies
|
||||
gold_tuples = PseudoProjectivity.preprocess_training_data(gold_tuples)
|
||||
contexts = []
|
||||
if cfg.get('use_gpu'):
|
||||
Model.ops = CupyOps()
|
||||
Model.Ops = CupyOps
|
||||
print("Use GPU")
|
||||
for proc in self.pipeline:
|
||||
if hasattr(proc, 'begin_training'):
|
||||
context = proc.begin_training(gold_tuples,
|
||||
|
@ -205,6 +218,18 @@ class Language(object):
|
|||
trainer = Trainer(self, gold_tuples, **cfg)
|
||||
yield trainer, trainer.optimizer
|
||||
|
||||
@contextmanager
|
||||
def use_params(self, params, **cfg):
|
||||
contexts = [pipe.model.use_params(params) for pipe
|
||||
in self.pipeline if hasattr(pipe, 'model')
|
||||
and hasattr(pipe.model, 'use_params')]
|
||||
yield
|
||||
for context in contexts:
|
||||
try:
|
||||
next(context.gen)
|
||||
except StopIteration:
|
||||
pass
|
||||
|
||||
def pipe(self, texts, n_threads=2, batch_size=1000, **disabled):
|
||||
"""
|
||||
Process texts as a stream, and yield Doc objects in order.
|
||||
|
|
Loading…
Reference in New Issue