Trim flaky amp test (#15051)

This commit is contained in:
Carlos Mocholí 2022-10-10 13:49:37 +02:00 committed by GitHub
parent 5a3007cd6c
commit 69fee71f22
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 17 additions and 13 deletions

View File

@ -69,32 +69,36 @@ class AMPTestModel(BoringModel):
@RunIf(min_torch="1.10")
@pytest.mark.parametrize(
"strategy",
[
None,
pytest.param("dp", marks=pytest.mark.skip("dp + amp not supported on CPU currently")), # TODO
"ddp_spawn",
],
("strategy", "precision", "devices"),
(
("single_device", 16, 1),
("single_device", "bf16", 1),
("ddp_spawn", 16, 2),
("ddp_spawn", "bf16", 2),
),
)
@pytest.mark.parametrize("precision", [16, "bf16"])
@pytest.mark.parametrize("devices", [1, 2])
def test_amp_cpus(tmpdir, strategy, precision, devices):
"""Make sure combinations of AMP and strategies work if supported."""
tutils.reset_seed()
trainer = Trainer(
default_root_dir=tmpdir,
accelerator="cpu",
devices=devices,
max_epochs=1,
strategy=strategy,
precision=precision,
max_epochs=1,
limit_train_batches=1,
limit_val_batches=1,
limit_test_batches=1,
limit_predict_batches=1,
logger=False,
enable_checkpointing=False,
enable_model_summary=False,
enable_progress_bar=False,
)
model = AMPTestModel()
trainer.fit(model)
trainer.test(model)
trainer.predict(model, DataLoader(RandomDataset(32, 64)))
trainer.predict(model)
@RunIf(min_cuda_gpus=2, min_torch="1.10")