[test] Add checks for gpus=1 (#7105)

* update

* remove cluster env
This commit is contained in:
thomas chaton 2021-04-19 19:39:28 +01:00 committed by GitHub
parent 6b15ca95f0
commit 3cc0b2c063
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 3 additions and 2 deletions

View File

@ -466,13 +466,14 @@ def test_plugin_accelerator_choice(accelerator: Optional[str], plugin: str):
])
@mock.patch('torch.cuda.is_available', return_value=True)
@mock.patch('torch.cuda.device_count', return_value=2)
@pytest.mark.parametrize("gpus", [1, 2])
def test_accelerator_choice_multi_node_gpu(
mock_is_available, mock_device_count, tmpdir, accelerator: str, plugin: ParallelPlugin
mock_is_available, mock_device_count, tmpdir, accelerator: str, plugin: ParallelPlugin, gpus: int
):
trainer = Trainer(
accelerator=accelerator,
default_root_dir=tmpdir,
num_nodes=2,
gpus=2,
gpus=gpus,
)
assert isinstance(trainer.training_type_plugin, plugin)