mirror of https://github.com/explosion/spaCy.git
Fix `test_{prefer,require}_gpu` (#11390)
* Fix `test_{prefer,require}_gpu` These tests assumed that GPUs are only supported with CuPy, but since Thinc 8.1 we also support Metal Performance Shaders. * test_misc: arrange thinc imports to be together
This commit is contained in:
parent
5ae63b1fbd
commit
3f4b4b7b4f
|
@ -10,7 +10,8 @@ from spacy.ml._precomputable_affine import _backprop_precomputable_affine_paddin
|
|||
from spacy.util import dot_to_object, SimpleFrozenList, import_file
|
||||
from spacy.util import to_ternary_int
|
||||
from thinc.api import Config, Optimizer, ConfigValidationError
|
||||
from thinc.api import set_current_ops
|
||||
from thinc.api import get_current_ops, set_current_ops, NumpyOps, CupyOps, MPSOps
|
||||
from thinc.compat import has_cupy_gpu, has_torch_mps_gpu
|
||||
from spacy.training.batchers import minibatch_by_words
|
||||
from spacy.lang.en import English
|
||||
from spacy.lang.nl import Dutch
|
||||
|
@ -18,7 +19,6 @@ from spacy.language import DEFAULT_CONFIG_PATH
|
|||
from spacy.schemas import ConfigSchemaTraining, TokenPattern, TokenPatternSchema
|
||||
from pydantic import ValidationError
|
||||
|
||||
from thinc.api import get_current_ops, NumpyOps, CupyOps
|
||||
|
||||
from .util import get_random_doc, make_tempdir
|
||||
|
||||
|
@ -111,26 +111,25 @@ def test_PrecomputableAffine(nO=4, nI=5, nF=3, nP=2):
|
|||
|
||||
def test_prefer_gpu():
|
||||
current_ops = get_current_ops()
|
||||
try:
|
||||
import cupy # noqa: F401
|
||||
|
||||
prefer_gpu()
|
||||
if has_cupy_gpu:
|
||||
assert prefer_gpu()
|
||||
assert isinstance(get_current_ops(), CupyOps)
|
||||
except ImportError:
|
||||
elif has_torch_mps_gpu:
|
||||
assert prefer_gpu()
|
||||
assert isinstance(get_current_ops(), MPSOps)
|
||||
else:
|
||||
assert not prefer_gpu()
|
||||
set_current_ops(current_ops)
|
||||
|
||||
|
||||
def test_require_gpu():
|
||||
current_ops = get_current_ops()
|
||||
try:
|
||||
import cupy # noqa: F401
|
||||
|
||||
if has_cupy_gpu:
|
||||
require_gpu()
|
||||
assert isinstance(get_current_ops(), CupyOps)
|
||||
except ImportError:
|
||||
with pytest.raises(ValueError):
|
||||
require_gpu()
|
||||
elif has_torch_mps_gpu:
|
||||
require_gpu()
|
||||
assert isinstance(get_current_ops(), MPSOps)
|
||||
set_current_ops(current_ops)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue