diff --git a/tests/accelerators/test_accelerator_connector.py b/tests/accelerators/test_accelerator_connector.py
index de927aa5fd..a57fbb4afc 100644
--- a/tests/accelerators/test_accelerator_connector.py
+++ b/tests/accelerators/test_accelerator_connector.py
@@ -466,13 +466,14 @@ def test_plugin_accelerator_choice(accelerator: Optional[str], plugin: str):
 ])
 @mock.patch('torch.cuda.is_available', return_value=True)
 @mock.patch('torch.cuda.device_count', return_value=2)
+@pytest.mark.parametrize("gpus", [1, 2])
 def test_accelerator_choice_multi_node_gpu(
-    mock_is_available, mock_device_count, tmpdir, accelerator: str, plugin: ParallelPlugin
+    mock_is_available, mock_device_count, tmpdir, accelerator: str, plugin: ParallelPlugin, gpus: int
 ):
     trainer = Trainer(
         accelerator=accelerator,
         default_root_dir=tmpdir,
         num_nodes=2,
-        gpus=2,
+        gpus=gpus,
     )
     assert isinstance(trainer.training_type_plugin, plugin)