Trim flaky amp test (#15051)
This commit is contained in:
parent
5a3007cd6c
commit
69fee71f22
|
@ -69,32 +69,36 @@ class AMPTestModel(BoringModel):
|
|||
|
||||
@RunIf(min_torch="1.10")
|
||||
@pytest.mark.parametrize(
|
||||
"strategy",
|
||||
[
|
||||
None,
|
||||
pytest.param("dp", marks=pytest.mark.skip("dp + amp not supported on CPU currently")), # TODO
|
||||
"ddp_spawn",
|
||||
],
|
||||
("strategy", "precision", "devices"),
|
||||
(
|
||||
("single_device", 16, 1),
|
||||
("single_device", "bf16", 1),
|
||||
("ddp_spawn", 16, 2),
|
||||
("ddp_spawn", "bf16", 2),
|
||||
),
|
||||
)
|
||||
@pytest.mark.parametrize("precision", [16, "bf16"])
|
||||
@pytest.mark.parametrize("devices", [1, 2])
|
||||
def test_amp_cpus(tmpdir, strategy, precision, devices):
|
||||
"""Make sure combinations of AMP and strategies work if supported."""
|
||||
tutils.reset_seed()
|
||||
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
accelerator="cpu",
|
||||
devices=devices,
|
||||
max_epochs=1,
|
||||
strategy=strategy,
|
||||
precision=precision,
|
||||
max_epochs=1,
|
||||
limit_train_batches=1,
|
||||
limit_val_batches=1,
|
||||
limit_test_batches=1,
|
||||
limit_predict_batches=1,
|
||||
logger=False,
|
||||
enable_checkpointing=False,
|
||||
enable_model_summary=False,
|
||||
enable_progress_bar=False,
|
||||
)
|
||||
|
||||
model = AMPTestModel()
|
||||
trainer.fit(model)
|
||||
trainer.test(model)
|
||||
trainer.predict(model, DataLoader(RandomDataset(32, 64)))
|
||||
trainer.predict(model)
|
||||
|
||||
|
||||
@RunIf(min_cuda_gpus=2, min_torch="1.10")
|
||||
|
|
Loading…
Reference in New Issue