diff --git a/tests/tests_pytorch/models/test_amp.py b/tests/tests_pytorch/models/test_amp.py index 8013369f5c..fe88ffea83 100644 --- a/tests/tests_pytorch/models/test_amp.py +++ b/tests/tests_pytorch/models/test_amp.py @@ -69,32 +69,36 @@ class AMPTestModel(BoringModel): @RunIf(min_torch="1.10") @pytest.mark.parametrize( - "strategy", - [ - None, - pytest.param("dp", marks=pytest.mark.skip("dp + amp not supported on CPU currently")), # TODO - "ddp_spawn", - ], + ("strategy", "precision", "devices"), + ( + ("single_device", 16, 1), + ("single_device", "bf16", 1), + ("ddp_spawn", 16, 2), + ("ddp_spawn", "bf16", 2), + ), ) -@pytest.mark.parametrize("precision", [16, "bf16"]) -@pytest.mark.parametrize("devices", [1, 2]) def test_amp_cpus(tmpdir, strategy, precision, devices): """Make sure combinations of AMP and strategies work if supported.""" - tutils.reset_seed() - trainer = Trainer( default_root_dir=tmpdir, accelerator="cpu", devices=devices, - max_epochs=1, strategy=strategy, precision=precision, + max_epochs=1, + limit_train_batches=1, + limit_val_batches=1, + limit_test_batches=1, + limit_predict_batches=1, + logger=False, + enable_checkpointing=False, + enable_model_summary=False, + enable_progress_bar=False, ) - model = AMPTestModel() trainer.fit(model) trainer.test(model) - trainer.predict(model, DataLoader(RandomDataset(32, 64))) + trainer.predict(model) @RunIf(min_cuda_gpus=2, min_torch="1.10")