This commit is contained in:
awaelchli 2024-06-23 18:35:54 +02:00
parent dc18138032
commit d782e1fbb3
1 changed files with 18 additions and 18 deletions

View File

@ -345,7 +345,7 @@ class LoggingSyncDistModel(BoringModel):
@pytest.mark.parametrize( @pytest.mark.parametrize(
("devices", "accelerator"), ("devices", "accelerator"),
[ [
# (1, "cpu"), (1, "cpu"),
(2, "cpu"), (2, "cpu"),
pytest.param(2, "gpu", marks=RunIf(min_cuda_gpus=2)), pytest.param(2, "gpu", marks=RunIf(min_cuda_gpus=2)),
], ],
@ -368,23 +368,23 @@ def test_logging_sync_dist_true(tmp_path, devices, accelerator):
) )
trainer.fit(model) trainer.fit(model)
# total = fake_result * devices + 1 total = fake_result * devices + 1
# metrics = trainer.callback_metrics metrics = trainer.callback_metrics
# assert metrics["foo"] == total if use_multiple_devices else fake_result assert metrics["foo"] == total if use_multiple_devices else fake_result
# assert metrics["foo_2"] == 2 * devices assert metrics["foo_2"] == 2 * devices
# assert metrics["foo_3"] == 2 assert metrics["foo_3"] == 2
# assert metrics["foo_4"] == total / devices if use_multiple_devices else 1 assert metrics["foo_4"] == total / devices if use_multiple_devices else 1
# assert metrics["foo_5"] == fake_result * 2 + 1 if use_multiple_devices else fake_result * 2 assert metrics["foo_5"] == fake_result * 2 + 1 if use_multiple_devices else fake_result * 2
# assert metrics["foo_6"] == (0 + 1 + 1 + 2 + 2 + 3) if use_multiple_devices else fake_result * 3 * 2 assert metrics["foo_6"] == (0 + 1 + 1 + 2 + 2 + 3) if use_multiple_devices else fake_result * 3 * 2
# assert metrics["foo_7"] == 2 * devices * 3 assert metrics["foo_7"] == 2 * devices * 3
# assert metrics["foo_8"] == 2 assert metrics["foo_8"] == 2
# assert metrics["foo_9"] == (fake_result * 2 + 1) / devices if use_multiple_devices else fake_result assert metrics["foo_9"] == (fake_result * 2 + 1) / devices if use_multiple_devices else fake_result
# assert metrics["foo_10"] == 2 assert metrics["foo_10"] == 2
# assert metrics["foo_11_step"] == (2 + 3) / 2 if use_multiple_devices else fake_result * 2 assert metrics["foo_11_step"] == (2 + 3) / 2 if use_multiple_devices else fake_result * 2
# assert metrics["foo_11"] == (0 + 1 + 1 + 2 + 2 + 3) / (devices * 3) if use_multiple_devices else fake_result assert metrics["foo_11"] == (0 + 1 + 1 + 2 + 2 + 3) / (devices * 3) if use_multiple_devices else fake_result
# assert metrics["bar"] == fake_result * 3 * devices assert metrics["bar"] == fake_result * 3 * devices
# assert metrics["bar_2"] == fake_result assert metrics["bar_2"] == fake_result
# assert metrics["bar_3"] == 2 + int(use_multiple_devices) assert metrics["bar_3"] == 2 + int(use_multiple_devices)
@RunIf(min_cuda_gpus=2, standalone=True) @RunIf(min_cuda_gpus=2, standalone=True)