Add test to verify that lowering gpus on restart works with sharded spawn (#15317)

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Authored by Rohit Gupta on 2022-11-08 18:33:29 +05:30; committed by GitHub
parent f9a65731cd
commit 1cd66b6d7c
1 changed file with 63 additions and 0 deletions
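
The new test exercises the following user-facing flow: train with strategy="ddp_sharded_spawn" on two GPUs, then resume the run from the resulting checkpoint on a single GPU. A minimal sketch of that flow (not part of the diff; the BoringModel import path and checkpoint handling are assumptions):

# Minimal sketch of the scenario the test below verifies (assumed imports).
from pytorch_lightning import Trainer
from pytorch_lightning.demos.boring_classes import BoringModel

model = BoringModel()
trainer = Trainer(strategy="ddp_sharded_spawn", accelerator="gpu", devices=2, max_epochs=1)
trainer.fit(model)
ckpt_path = trainer.checkpoint_callback.best_model_path

# Restart with fewer devices; the optimizer state saved by the sharded strategy
# must be restored correctly even though the world size changed.
trainer = Trainer(strategy="ddp_sharded_spawn", accelerator="gpu", devices=1, max_epochs=2)
trainer.fit(BoringModel(), ckpt_path=ckpt_path)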


@@ -1,4 +1,6 @@
import os
from copy import deepcopy
from typing import Mapping
from unittest import mock
from unittest.mock import Mock
@@ -18,6 +20,37 @@ if _FAIRSCALE_AVAILABLE:
    from fairscale.optim import OSS


class ModelWithAdamOptimizer(BoringModel):
    # Minimal model whose optimizer (Adam) carries per-parameter state that must survive a restart.
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.layer.parameters(), lr=0.1)
        return optimizer


class CheckModelRestore(ModelWithAdamOptimizer):
    """Checks at the start of training that the restored states match the loaded checkpoint."""

    def __init__(self, old_model_state_dict, old_optimizer_states):
        super().__init__()
        self.old_model_state_dict = old_model_state_dict
        self.old_optimizer_states = old_optimizer_states

    def on_train_start(self):
        # The restored model state must match the copy taken from the checkpoint
        assert all(
            self._is_equal(actual, expected) for actual, expected in zip(self.state_dict(), self.old_model_state_dict)
        )

        # Compare the (consolidated) optimizer state reported by the strategy with the checkpointed one
        for optimizer, state in zip(self.trainer.optimizers, self.old_optimizer_states):
            optimizer_state = self.trainer.strategy.optimizer_state(optimizer)
            self._is_equal(optimizer_state, state)

    def _is_equal(self, a, b):
        # Recursive comparison: tensors via allclose, mappings key by key, everything else via ==
        if isinstance(a, torch.Tensor):
            return torch.allclose(a, b)

        if isinstance(a, Mapping):
            return all(self._is_equal(a.get(k, None), b.get(k, None)) for k in b.keys())

        return a == b


@pytest.mark.parametrize("clip_val", [0, 10])
@RunIf(min_cuda_gpus=1, fairscale=True)
@mock.patch("fairscale.optim.oss.OSS.clip_grad_norm")
@@ -324,3 +357,33 @@ def test_ddp_sharded_strategy_checkpoint_multi_gpu_fairscale_optimizer(tmpdir, s
    # Assert model parameters are identical after loading
    for trained_param, loaded_param in zip(model.parameters(), saved_model.parameters()):
        assert torch.equal(trained_param.to("cpu"), loaded_param)


@RunIf(min_cuda_gpus=2, fairscale=True)
def test_ddp_sharded_strategy_fit_ckpt_path_downsize_gpus(tmpdir):
    # First run: train for one epoch on 2 GPUs with the sharded spawn strategy
    model = ModelWithAdamOptimizer()
    trainer = Trainer(
        strategy="ddp_sharded_spawn",
        max_epochs=1,
        limit_train_batches=1,
        limit_val_batches=0,
        accelerator="gpu",
        devices=2,
    )
    trainer.fit(model)

    # Capture the checkpointed weights and optimizer states for later comparison
    checkpoint_path = trainer.checkpoint_callback.best_model_path
    ckpt = torch.load(checkpoint_path)
    old_model_state_dict = deepcopy(ckpt["state_dict"])
    old_optimizer_states = deepcopy(ckpt["optimizer_states"])

    # Second run: resume from the same checkpoint on a single GPU; CheckModelRestore
    # asserts at train start that the restored states match the checkpoint
    model = CheckModelRestore(old_model_state_dict, old_optimizer_states)
    trainer = Trainer(
        strategy="ddp_sharded_spawn",
        max_epochs=2,
        limit_train_batches=1,
        limit_val_batches=0,
        accelerator="gpu",
        devices=1,
    )
    trainer.fit(model, ckpt_path=checkpoint_path)
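
For reference, the checkpoint keys the assertions rely on can be inspected directly from a saved checkpoint; a small sketch (the path below is illustrative, not from the diff):

# The checkpoint stores the model weights under "state_dict" and one entry per
# optimizer under "optimizer_states", which is exactly what the test compares.
ckpt = torch.load("path/to/checkpoint.ckpt", map_location="cpu")  # illustrative path
print(list(ckpt["state_dict"])[:5])   # parameter names
print(len(ckpt["optimizer_states"]))  # number of optimizers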