diff --git a/docs/source/common/trainer.rst b/docs/source/common/trainer.rst index 138b14058d..f9275fcbd8 100644 --- a/docs/source/common/trainer.rst +++ b/docs/source/common/trainer.rst @@ -1156,7 +1156,7 @@ precision | Double precision (64), full precision (32) or half precision (16). -Can be used on CPU, GPU or TPUs. +Can all be used on GPU or TPUs. Only double (64) and full precision (32) available on CPU. If used on TPU will use torch.bfloat16 but tensor printing will still show torch.float32. diff --git a/pytorch_lightning/accelerators/cpu.py b/pytorch_lightning/accelerators/cpu.py index 458f058c27..0da70c8a5c 100644 --- a/pytorch_lightning/accelerators/cpu.py +++ b/pytorch_lightning/accelerators/cpu.py @@ -27,9 +27,15 @@ class CPUAccelerator(Accelerator): If AMP is used with CPU, or if the selected device is not CPU. """ if isinstance(self.precision_plugin, MixedPrecisionPlugin): - raise MisconfigurationException("amp + cpu is not supported. Please use a GPU option") + raise MisconfigurationException( + "Mixed precision is currently only supported with the AMP backend" + " and AMP + CPU is not supported. Please use a GPU option or" + " change the precision setting." + ) if "cpu" not in str(self.root_device): - raise MisconfigurationException(f"Device should be CPU, got {self.root_device} instead") + raise MisconfigurationException( + f"Device should be CPU, got {self.root_device} instead." 
+ ) return super().setup(trainer, model) diff --git a/tests/accelerators/test_cpu.py b/tests/accelerators/test_cpu.py index 46379a9d10..c7d7f98ae9 100644 --- a/tests/accelerators/test_cpu.py +++ b/tests/accelerators/test_cpu.py @@ -18,7 +18,7 @@ def test_unsupported_precision_plugins(): accelerator = CPUAccelerator( training_type_plugin=SingleDevicePlugin(torch.device("cpu")), precision_plugin=MixedPrecisionPlugin() ) - with pytest.raises(MisconfigurationException, match=r"amp \+ cpu is not supported."): + with pytest.raises(MisconfigurationException, match=r"AMP \+ CPU is not supported"): accelerator.setup(trainer=trainer, model=model)