Updating docs and error message: half precision not available on CPU (#7384)

* Updating docs and error message to specify that half precision is not available on CPU

* update messages

Co-authored-by: Martin Kristiansen <martinkristiansen@sixgill.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: jirka <jirka.borovec@seznam.cz>
Martin Kristiansen 2021-05-06 05:05:50 -04:00 committed by GitHub
parent dea7a0230d
commit c3fc0313ef
3 changed files with 10 additions and 4 deletions


@@ -1156,7 +1156,7 @@ precision
 |
 Double precision (64), full precision (32) or half precision (16).
-Can be used on CPU, GPU or TPUs.
+Can all be used on GPUs or TPUs. Only double (64) and full precision (32) are available on CPU.
 If used on TPU will use torch.bfloat16 but tensor printing
 will still show torch.float32.
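
For context, a minimal sketch of the documented behavior (hypothetical snippet, not part of this commit; flag names are from the 1.3-era Trainer API, and the GPU/TPU lines assume that hardware is present):

from pytorch_lightning import Trainer

Trainer(precision=32)               # full precision: works on CPU, GPU and TPU
Trainer(precision=64)               # double precision: also fine on CPU
Trainer(precision=16, gpus=1)       # half precision needs a GPU (native AMP) ...
Trainer(precision=16, tpu_cores=8)  # ... or a TPU, where torch.bfloat16 is used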


@@ -27,9 +27,15 @@ class CPUAccelerator(Accelerator):
                 If AMP is used with CPU, or if the selected device is not CPU.
         """
         if isinstance(self.precision_plugin, MixedPrecisionPlugin):
-            raise MisconfigurationException("amp + cpu is not supported. Please use a GPU option")
+            raise MisconfigurationException(
+                "Mixed precision is currently only supported with the AMP backend"
+                " and AMP + CPU is not supported. Please use a GPU option or"
+                " change the precision setting."
+            )
         if "cpu" not in str(self.root_device):
-            raise MisconfigurationException(f"Device should be CPU, got {self.root_device} instead")
+            raise MisconfigurationException(
+                f"Device should be CPU, got {self.root_device} instead."
+            )
         return super().setup(trainer, model)
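
A hedged usage sketch of the new check (it mirrors the updated test below and is not part of the commit). Because the isinstance guard fires before super().setup() is reached, trainer and model are never touched and can be None here:

import torch
from pytorch_lightning.accelerators import CPUAccelerator
from pytorch_lightning.plugins import MixedPrecisionPlugin, SingleDevicePlugin
from pytorch_lightning.utilities.exceptions import MisconfigurationException

accelerator = CPUAccelerator(
    training_type_plugin=SingleDevicePlugin(torch.device("cpu")),
    precision_plugin=MixedPrecisionPlugin(),
)
try:
    accelerator.setup(trainer=None, model=None)  # raises before trainer/model are used
except MisconfigurationException as err:
    print(err)  # "Mixed precision is currently only supported with the AMP backend ..."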


@@ -18,7 +18,7 @@ def test_unsupported_precision_plugins():
     accelerator = CPUAccelerator(
         training_type_plugin=SingleDevicePlugin(torch.device("cpu")), precision_plugin=MixedPrecisionPlugin()
     )
-    with pytest.raises(MisconfigurationException, match=r"amp \+ cpu is not supported."):
+    with pytest.raises(MisconfigurationException, match=r"AMP \+ CPU is not supported"):
         accelerator.setup(trainer=trainer, model=model)
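
A note on the updated match pattern: pytest.raises(match=...) runs re.search against the exception string, so the literal "+" must stay escaped. A quick illustration (not part of the commit):

import re

message = (
    "Mixed precision is currently only supported with the AMP backend"
    " and AMP + CPU is not supported."
)
assert re.search(r"AMP \+ CPU is not supported", message)  # escaped "+" matches the literal character
assert not re.search(r"AMP + CPU", message)  # unescaped "+" quantifies the preceding space, so no match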