Raise `MisconfigurationException` when the accelerator is available but… (#12708)
This commit is contained in:
parent
a86ef06b3b
commit
91c3d8ecb9
|
@ -120,7 +120,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Don't raise a warning when `nn.Module` is not saved under hparams ([#12669](https://github.com/PyTorchLightning/pytorch-lightning/pull/12669))
|
||||
|
||||
|
||||
-
|
||||
- Raise `MisconfigurationException` when the accelerator is available but the user passes invalid `([]/0/"0")` values to the `devices` flag ([#12708](https://github.com/PyTorchLightning/pytorch-lightning/pull/12708))
|
||||
|
||||
|
||||
## [1.6.0] - 2022-03-29
|
||||
|
|
|
@ -412,6 +412,17 @@ class AcceleratorConnector:
|
|||
self._num_nodes_flag = int(num_nodes) if num_nodes is not None else 1
|
||||
self._devices_flag = devices
|
||||
|
||||
if self._devices_flag in ([], 0, "0"):
|
||||
accelerator_name = (
|
||||
self._accelerator_flag.__class__.__qualname__
|
||||
if isinstance(self._accelerator_flag, Accelerator)
|
||||
else self._accelerator_flag
|
||||
)
|
||||
raise MisconfigurationException(
|
||||
f"`Trainer(devices={self._devices_flag!r})` value is not a valid input"
|
||||
f" using {accelerator_name} accelerator."
|
||||
)
|
||||
|
||||
# TODO: Delete this method when num_processes, gpus, ipus and tpu_cores gets removed
|
||||
self._map_deprecated_devices_specific_info_to_accelerator_and_device_flag(
|
||||
devices, num_processes, gpus, ipus, tpu_cores
|
||||
|
|
|
@ -504,15 +504,6 @@ def test_accelerator_cpu(_):
|
|||
trainer = Trainer(accelerator="cpu", gpus=1)
|
||||
|
||||
|
||||
@mock.patch("torch.cuda.is_available", return_value=False)
|
||||
@pytest.mark.parametrize("devices", ["0", 0, []])
|
||||
def test_passing_zero_and_empty_list_to_devices_flag(_, devices):
|
||||
with pytest.raises(
|
||||
MisconfigurationException, match="can not run on your system since the accelerator is not available."
|
||||
):
|
||||
Trainer(accelerator="gpu", devices=devices)
|
||||
|
||||
|
||||
@RunIf(min_gpus=1)
|
||||
def test_accelerator_gpu():
|
||||
trainer = Trainer(accelerator="gpu", devices=1)
|
||||
|
@ -1014,3 +1005,10 @@ def test_sync_batchnorm_set_in_custom_strategy(tmpdir):
|
|||
def test_plugin_only_one_instance_for_one_type(plugins, expected):
|
||||
with pytest.raises(MisconfigurationException, match=f"Received multiple values for {expected}"):
|
||||
Trainer(plugins=plugins)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("accelerator", ("cpu", "gpu", "tpu", "ipu"))
|
||||
@pytest.mark.parametrize("devices", ("0", 0, []))
|
||||
def test_passing_zero_and_empty_list_to_devices_flag(accelerator, devices):
|
||||
with pytest.raises(MisconfigurationException, match="value is not a valid input using"):
|
||||
Trainer(accelerator=accelerator, devices=devices)
|
||||
|
|
|
@ -69,7 +69,7 @@ def test_restore_checkpoint_after_pre_setup(tmpdir, restore_after_pre_setup):
|
|||
func(model, ckpt_path=checkpoint_path)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("devices", ([3], -1, 0))
|
||||
@pytest.mark.parametrize("devices", ([3], -1))
|
||||
def test_invalid_devices_with_cpu_accelerator(devices):
|
||||
"""Test invalid device flag raises MisconfigurationException with CPUAccelerator."""
|
||||
with pytest.raises(MisconfigurationException, match="should be an int > 0"):
|
||||
|
|
Loading…
Reference in New Issue