From 91c3d8ecb9a91c0ab197309f1c0bf12d53b1fb5d Mon Sep 17 00:00:00 2001 From: Kaushik B <45285388+kaushikb11@users.noreply.github.com> Date: Mon, 11 Apr 2022 23:31:17 +0530 Subject: [PATCH] =?UTF-8?q?Raise=20`MisconfigurationException`=20when=20th?= =?UTF-8?q?e=20accelerator=20is=20available=20but=E2=80=A6=20(#12708)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- CHANGELOG.md | 2 +- .../trainer/connectors/accelerator_connector.py | 11 +++++++++++ tests/accelerators/test_accelerator_connector.py | 16 +++++++--------- tests/accelerators/test_cpu.py | 2 +- 4 files changed, 20 insertions(+), 11 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5a31916976..9abc1ccf64 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -120,7 +120,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Don't raise a warning when `nn.Module` is not saved under hparams ([#12669](https://github.com/PyTorchLightning/pytorch-lightning/pull/12669)) -- +- Raise `MisconfigurationException` when the accelerator is available but the user passes invalid `([]/0/"0")` values to the `devices` flag ([#12708](https://github.com/PyTorchLightning/pytorch-lightning/pull/12708)) ## [1.6.0] - 2022-03-29 diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index 174051c44c..f4932a2ae8 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -412,6 +412,17 @@ class AcceleratorConnector: self._num_nodes_flag = int(num_nodes) if num_nodes is not None else 1 self._devices_flag = devices + if self._devices_flag in ([], 0, "0"): + accelerator_name = ( + self._accelerator_flag.__class__.__qualname__ + if isinstance(self._accelerator_flag, Accelerator) + else self._accelerator_flag + ) + raise MisconfigurationException( + f"`Trainer(devices={self._devices_flag!r})` value is not a valid input" + f" using {accelerator_name} accelerator." + ) + # TODO: Delete this method when num_processes, gpus, ipus and tpu_cores gets removed self._map_deprecated_devices_specific_info_to_accelerator_and_device_flag( devices, num_processes, gpus, ipus, tpu_cores diff --git a/tests/accelerators/test_accelerator_connector.py b/tests/accelerators/test_accelerator_connector.py index 02c946e186..1d64705d1b 100644 --- a/tests/accelerators/test_accelerator_connector.py +++ b/tests/accelerators/test_accelerator_connector.py @@ -504,15 +504,6 @@ def test_accelerator_cpu(_): trainer = Trainer(accelerator="cpu", gpus=1) -@mock.patch("torch.cuda.is_available", return_value=False) -@pytest.mark.parametrize("devices", ["0", 0, []]) -def test_passing_zero_and_empty_list_to_devices_flag(_, devices): - with pytest.raises( - MisconfigurationException, match="can not run on your system since the accelerator is not available." - ): - Trainer(accelerator="gpu", devices=devices) - - @RunIf(min_gpus=1) def test_accelerator_gpu(): trainer = Trainer(accelerator="gpu", devices=1) @@ -1014,3 +1005,10 @@ def test_sync_batchnorm_set_in_custom_strategy(tmpdir): def test_plugin_only_one_instance_for_one_type(plugins, expected): with pytest.raises(MisconfigurationException, match=f"Received multiple values for {expected}"): Trainer(plugins=plugins) + + +@pytest.mark.parametrize("accelerator", ("cpu", "gpu", "tpu", "ipu")) +@pytest.mark.parametrize("devices", ("0", 0, [])) +def test_passing_zero_and_empty_list_to_devices_flag(accelerator, devices): + with pytest.raises(MisconfigurationException, match="value is not a valid input using"): + Trainer(accelerator=accelerator, devices=devices) diff --git a/tests/accelerators/test_cpu.py b/tests/accelerators/test_cpu.py index 8abacd47af..ce9f6d6b21 100644 --- a/tests/accelerators/test_cpu.py +++ b/tests/accelerators/test_cpu.py @@ -69,7 +69,7 @@ def test_restore_checkpoint_after_pre_setup(tmpdir, restore_after_pre_setup): func(model, ckpt_path=checkpoint_path) -@pytest.mark.parametrize("devices", ([3], -1, 0)) +@pytest.mark.parametrize("devices", ([3], -1)) def test_invalid_devices_with_cpu_accelerator(devices): """Test invalid device flag raises MisconfigurationException with CPUAccelerator.""" with pytest.raises(MisconfigurationException, match="should be an int > 0"):