parent
6b15ca95f0
commit
3cc0b2c063
|
@ -466,13 +466,14 @@ def test_plugin_accelerator_choice(accelerator: Optional[str], plugin: str):
|
|||
])
|
||||
@mock.patch('torch.cuda.is_available', return_value=True)
|
||||
@mock.patch('torch.cuda.device_count', return_value=2)
|
||||
@pytest.mark.parametrize("gpus", [1, 2])
|
||||
def test_accelerator_choice_multi_node_gpu(
|
||||
mock_is_available, mock_device_count, tmpdir, accelerator: str, plugin: ParallelPlugin
|
||||
mock_is_available, mock_device_count, tmpdir, accelerator: str, plugin: ParallelPlugin, gpus: int
|
||||
):
|
||||
trainer = Trainer(
|
||||
accelerator=accelerator,
|
||||
default_root_dir=tmpdir,
|
||||
num_nodes=2,
|
||||
gpus=2,
|
||||
gpus=gpus,
|
||||
)
|
||||
assert isinstance(trainer.training_type_plugin, plugin)
|
||||
|
|
Loading…
Reference in New Issue