diff --git a/tests/accelerators/test_accelerator_connector.py b/tests/accelerators/test_accelerator_connector.py index 9e39bf215a..a61a14a6aa 100644 --- a/tests/accelerators/test_accelerator_connector.py +++ b/tests/accelerators/test_accelerator_connector.py @@ -385,6 +385,7 @@ def test_accelerator_choice_ddp_cpu_slurm(device_count_mock, setup_distributed_m trainer.fit(model) +@mock.patch.dict(os.environ, {}) @mock.patch('torch.cuda.device_count', return_value=0) @pytest.mark.parametrize("ddp_plugin_class", [DDPPlugin, DDPSpawnPlugin]) def test_accelerator_choice_ddp_cpu_custom_plugin(_, ddp_plugin_class):