From 3cc0b2c063e7d9f172f8060a626f883cca1bae93 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Mon, 19 Apr 2021 19:39:28 +0100 Subject: [PATCH] [test] Add checks for gpus=1 (#7105) * update * remove cluster env --- tests/accelerators/test_accelerator_connector.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/accelerators/test_accelerator_connector.py b/tests/accelerators/test_accelerator_connector.py index de927aa5fd..a57fbb4afc 100644 --- a/tests/accelerators/test_accelerator_connector.py +++ b/tests/accelerators/test_accelerator_connector.py @@ -466,13 +466,14 @@ def test_plugin_accelerator_choice(accelerator: Optional[str], plugin: str): ]) @mock.patch('torch.cuda.is_available', return_value=True) @mock.patch('torch.cuda.device_count', return_value=2) +@pytest.mark.parametrize("gpus", [1, 2]) def test_accelerator_choice_multi_node_gpu( - mock_is_available, mock_device_count, tmpdir, accelerator: str, plugin: ParallelPlugin + mock_is_available, mock_device_count, tmpdir, accelerator: str, plugin: ParallelPlugin, gpus: int ): trainer = Trainer( accelerator=accelerator, default_root_dir=tmpdir, num_nodes=2, - gpus=2, + gpus=gpus, ) assert isinstance(trainer.training_type_plugin, plugin)