fix amp/apex misconfiguration error for cpu (#6107)

* fix weird test

* fix apex plugin test

* fix raise

* cpu test

* fix type

* add changelog
Adrian Wälchli 2021-02-22 01:02:31 +01:00 committed by GitHub
parent 97b4b3ee68
commit ae6ce17598
5 changed files with 40 additions and 99 deletions

View File

@@ -24,13 +24,16 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed incorrect yield logic for the amp autocast context manager ([#6080](https://github.com/PyTorchLightning/pytorch-lightning/pull/6080))
- Made the `Plugin.reduce` method more consistent across all Plugins to reflect a mean-reduction by default ([#6011](https://github.com/PyTorchLightning/pytorch-lightning/pull/6011)
- Made the `Plugin.reduce` method more consistent across all Plugins to reflect a mean-reduction by default ([#6011](https://github.com/PyTorchLightning/pytorch-lightning/pull/6011))
- Fixed priority of plugin/accelerator when setting distributed mode ([#6089](https://github.com/PyTorchLightning/pytorch-lightning/pull/6089))
- Move lightning module to correct device type when using LightningDistributedWrapper ([#6070](https://github.com/PyTorchLightning/pytorch-lightning/pull/6070)
- Move lightning module to correct device type when using LightningDistributedWrapper ([#6070](https://github.com/PyTorchLightning/pytorch-lightning/pull/6070))
- Fixed error message for AMP + CPU incompatibility ([#6107](https://github.com/PyTorchLightning/pytorch-lightning/pull/6107))
- Expose DeepSpeed loss parameters to allow users to fix loss instability ([#6115](https://github.com/PyTorchLightning/pytorch-lightning/pull/6115)
@@ -40,7 +43,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Added
- Added `DataType`, `AverageMethod` and `MDMCAverageMethod` enum in metrics ([#5657](https://github.com/PyTorchLightning/pytorch-lightning/pull/5689)
- Added `DataType`, `AverageMethod` and `MDMCAverageMethod` enum in metrics ([#5657](https://github.com/PyTorchLightning/pytorch-lightning/pull/5689))
- Added support for summarized model total params size in megabytes ([#5590](https://github.com/PyTorchLightning/pytorch-lightning/pull/5590))
- Added support for multiple train loaders ([#1959](https://github.com/PyTorchLightning/pytorch-lightning/pull/1959))
- Added `Accuracy` metric now generalizes to Top-k accuracy for (multi-dimensional) multi-class inputs using the `top_k` parameter ([#4838](https://github.com/PyTorchLightning/pytorch-lightning/pull/4838))

View File

@@ -7,7 +7,7 @@ class CPUAccelerator(Accelerator):
    def setup(self, trainer, model):
        if isinstance(self.precision_plugin, MixedPrecisionPlugin):
            MisconfigurationException("amp + cpu is not supported. Please use a GPU option")
            raise MisconfigurationException("amp + cpu is not supported. Please use a GPU option")
        if "cpu" not in str(self.root_device):
            raise MisconfigurationException(f"Device should be CPU, got {self.root_device} instead")

View File

@@ -0,0 +1,21 @@
from unittest.mock import Mock

import pytest
import torch

from pytorch_lightning.accelerators import CPUAccelerator
from pytorch_lightning.plugins import SingleDevicePlugin
from pytorch_lightning.plugins.precision import MixedPrecisionPlugin
from pytorch_lightning.utilities.exceptions import MisconfigurationException


def test_unsupported_precision_plugins():
    """ Test error messages are raised for unsupported precision plugins with CPU. """
    trainer = Mock()
    model = Mock()
    accelerator = CPUAccelerator(
        training_type_plugin=SingleDevicePlugin(torch.device("cpu")),
        precision_plugin=MixedPrecisionPlugin()
    )
    with pytest.raises(MisconfigurationException, match=r"amp \+ cpu is not supported."):
        accelerator.setup(trainer=trainer, model=model)
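One detail of the new test worth noting: the `match` argument of `pytest.raises` is a regular expression applied with `re.search` against the exception message, which is why the literal `+` in "amp + cpu" is escaped as `\+`. A small self-contained illustration of that behavior:

```python
import re

import pytest


def test_match_is_treated_as_a_regex():
    # pytest.raises(..., match=pattern) applies re.search(pattern, str(exc)),
    # so the literal '+' in the message must be escaped as '\+'.
    with pytest.raises(ValueError, match=r"amp \+ cpu is not supported"):
        raise ValueError("amp + cpu is not supported. Please use a GPU option")

    # Roughly the check pytest performs internally:
    assert re.search(r"amp \+ cpu is not supported", "amp + cpu is not supported.")


test_match_is_treated_as_a_regex()
```

The `Mock()` trainer and model are sufficient in the test above because the misconfiguration check sits at the top of `setup`, so neither object is touched before the exception is raised.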

View File

@@ -5,10 +5,8 @@ import pytest
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.plugins import NativeMixedPrecisionPlugin
from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers.boring_model import BoringModel
@@ -25,78 +23,21 @@ from tests.helpers.boring_model import BoringModel
)
@mock.patch('torch.cuda.device_count', return_value=2)
@pytest.mark.parametrize(
['ddp_backend', 'gpus', 'num_processes'],
[('ddp_cpu', None, 2), ('ddp', 2, 0), ('ddp2', 2, 0), ('ddp_spawn', 2, 0)],
['ddp_backend', 'gpus'],
[('ddp', 2), ('ddp2', 2), ('ddp_spawn', 2)],
)
def test_amp_choice_default_ddp_cpu(tmpdir, ddp_backend, gpus, num_processes):
class CB(Callback):
def on_fit_start(self, trainer, pl_module):
assert isinstance(trainer.precision_plugin, NativeMixedPrecisionPlugin)
raise SystemExit()
def train():
model = BoringModel()
trainer = Trainer(
fast_dev_run=True,
precision=16,
amp_backend='native',
gpus=gpus,
num_processes=num_processes,
accelerator=ddp_backend,
callbacks=[CB()],
)
trainer.fit(model)
if ddp_backend == "ddp_cpu":
with pytest.raises(MisconfigurationException, match="MP is only available on GPU"):
train()
else:
with pytest.raises(SystemExit):
train()
@pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="Minimal PT version is set to 1.6")
@mock.patch.dict(
os.environ, {
"CUDA_VISIBLE_DEVICES": "0,1",
"SLURM_NTASKS": "2",
"SLURM_JOB_NAME": "SOME_NAME",
"SLURM_NODEID": "0",
"LOCAL_RANK": "0",
"SLURM_LOCALID": "0"
}
)
@mock.patch('torch.cuda.device_count', return_value=2)
@pytest.mark.parametrize(
['ddp_backend', 'gpus', 'num_processes'],
[('ddp_cpu', None, 2), ('ddp', 2, 0), ('ddp2', 2, 0), ('ddp_spawn', 2, 0)],
)
def test_amp_choice_custom_ddp_cpu(tmpdir, ddp_backend, gpus, num_processes):
def test_amp_choice_custom_ddp_cpu(device_count_mock, ddp_backend, gpus):
class MyNativeAMP(NativeMixedPrecisionPlugin):
pass
class CB(Callback):
def on_fit_start(self, trainer, pl_module):
assert isinstance(trainer.precision_plugin, MyNativeAMP)
raise SystemExit()
model = BoringModel()
trainer = Trainer(
fast_dev_run=True,
precision=16,
amp_backend='native',
num_processes=num_processes,
accelerator=ddp_backend,
plugins=[MyNativeAMP()],
callbacks=[CB()],
)
with pytest.raises(SystemExit):
trainer.fit(model)
assert isinstance(trainer.precision_plugin, MyNativeAMP)
class GradientUnscaleBoringModel(BoringModel):
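Several of the removed test bodies above rely on a short-circuit pattern: a callback asserts on the configured plugin inside `on_fit_start` and then raises `SystemExit`, so no real training ever runs; the surrounding `pytest.raises(SystemExit)` both absorbs the exit and verifies the hook was actually reached. A framework-free sketch of the idea (`FakeTrainer` and the other names are illustrative, not Lightning APIs):

```python
import pytest


class FakeTrainer:
    """Illustrative stand-in for a trainer that runs callbacks before training."""

    def __init__(self, precision_plugin, callbacks):
        self.precision_plugin = precision_plugin
        self.callbacks = callbacks

    def fit(self, model):
        for cb in self.callbacks:
            cb.on_fit_start(self, model)
        raise RuntimeError("real training would start here; the test never gets this far")


class AssertPluginThenExit:
    def on_fit_start(self, trainer, pl_module):
        # Verify the state configured during setup, then abort before any real work.
        assert trainer.precision_plugin == "native-amp"
        raise SystemExit()


def test_short_circuit_pattern():
    trainer = FakeTrainer(precision_plugin="native-amp", callbacks=[AssertPluginThenExit()])
    with pytest.raises(SystemExit):
        trainer.fit(model=object())


test_short_circuit_pattern()
```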

View File

@@ -4,10 +4,8 @@ from unittest import mock
import pytest
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.plugins import ApexMixedPrecisionPlugin
from pytorch_lightning.utilities import _APEX_AVAILABLE
from tests.helpers.boring_model import BoringModel
@pytest.mark.skipif(not _APEX_AVAILABLE, reason="test requires apex")
@@ -23,30 +21,19 @@ from tests.helpers.boring_model import BoringModel
)
@mock.patch('torch.cuda.device_count', return_value=2)
@pytest.mark.parametrize(
['ddp_backend', 'gpus', 'num_processes'],
[('ddp_cpu', None, 2), ('ddp', 2, 0), ('ddp2', 2, 0), ('ddp_spawn', 2, 0)],
['ddp_backend', 'gpus'],
[('ddp', 2), ('ddp2', 2), ('ddp_spawn', 2)],
)
def test_amp_choice_default_ddp_cpu(tmpdir, ddp_backend, gpus, num_processes):
def test_amp_choice_default_ddp(mocked_device_count, ddp_backend, gpus):
class CB(Callback):
def on_fit_start(self, trainer, pl_module):
assert isinstance(trainer.precision_plugin, ApexMixedPrecisionPlugin)
raise SystemExit()
model = BoringModel()
trainer = Trainer(
fast_dev_run=True,
precision=16,
amp_backend='apex',
gpus=gpus,
num_processes=num_processes,
accelerator=ddp_backend,
callbacks=[CB()],
)
with pytest.raises(SystemExit):
trainer.fit(model)
assert isinstance(trainer.precision_plugin, ApexMixedPrecisionPlugin)
@pytest.mark.skipif(not _APEX_AVAILABLE, reason="test requires apex")
@@ -62,31 +49,20 @@ def test_amp_choice_default_ddp_cpu(tmpdir, ddp_backend, gpus, num_processes):
)
@mock.patch('torch.cuda.device_count', return_value=2)
@pytest.mark.parametrize(
['ddp_backend', 'gpus', 'num_processes'],
[('ddp_cpu', None, 2), ('ddp', 2, 0), ('ddp2', 2, 0), ('ddp_spawn', 2, 0)],
['ddp_backend', 'gpus'],
[('ddp', 2), ('ddp2', 2), ('ddp_spawn', 2)],
)
def test_amp_choice_custom_ddp_cpu(tmpdir, ddp_backend, gpus, num_processes):
def test_amp_choice_custom_ddp(mocked_device_count, ddp_backend, gpus):
class MyApexPlugin(ApexMixedPrecisionPlugin):
pass
class CB(Callback):
def on_fit_start(self, trainer, pl_module):
assert isinstance(trainer.precision_plugin, MyApexPlugin)
raise SystemExit()
model = BoringModel()
trainer = Trainer(
fast_dev_run=True,
precision=16,
amp_backend='apex',
gpus=gpus,
num_processes=num_processes,
accelerator=ddp_backend,
plugins=[MyApexPlugin(amp_level="O2")],
callbacks=[CB()],
)
with pytest.raises(SystemExit):
trainer.fit(model)
assert isinstance(trainer.precision_plugin, MyApexPlugin)
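
A note on the signature changes in both test files (`tmpdir, ddp_backend, gpus, num_processes` becoming `device_count_mock, ddp_backend, gpus` or `mocked_device_count, ddp_backend, gpus`): when `mock.patch` is used as a decorator without an explicit replacement object, it creates a MagicMock, swaps it in for the patched attribute, and injects it as the first positional argument of the decorated function, so the renamed first parameter simply receives that mock. A minimal sketch of the mechanism (the function name here is illustrative):

```python
from unittest import mock

import torch


@mock.patch('torch.cuda.device_count', return_value=2)
def fake_gpu_check(device_count_mock):
    # mock.patch created a MagicMock whose return_value is 2, swapped it in for
    # torch.cuda.device_count, and passed it here as the first argument.
    assert torch.cuda.device_count() == 2
    assert device_count_mock.call_count == 1


fake_gpu_check()
```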