diff --git a/tests/tests_pytorch/strategies/test_sharded_strategy.py b/tests/tests_pytorch/strategies/test_sharded_strategy.py index 5950178869..1a5c6d68d9 100644 --- a/tests/tests_pytorch/strategies/test_sharded_strategy.py +++ b/tests/tests_pytorch/strategies/test_sharded_strategy.py @@ -1,4 +1,6 @@ import os +from copy import deepcopy +from typing import Mapping from unittest import mock from unittest.mock import Mock @@ -18,6 +20,37 @@ if _FAIRSCALE_AVAILABLE: from fairscale.optim import OSS +class ModelWithAdamOptimizer(BoringModel): + def configure_optimizers(self): + optimizer = torch.optim.Adam(self.layer.parameters(), lr=0.1) + return optimizer + + +class CheckModelRestore(ModelWithAdamOptimizer): + def __init__(self, old_model_state_dict, old_optimizer_states): + super().__init__() + self.old_model_state_dict = old_model_state_dict + self.old_optimizer_states = old_optimizer_states + + def on_train_start(self): + assert all( + self._is_equal(actual, expected) for actual, expected in zip(self.state_dict(), self.old_model_state_dict) + ) + + for optimizer, state in zip(self.trainer.optimizers, self.old_optimizer_states): + optimizer_state = self.trainer.strategy.optimizer_state(optimizer) + self._is_equal(optimizer_state, state) + + def _is_equal(self, a, b): + if isinstance(a, torch.Tensor): + return torch.allclose(a, b) + + if isinstance(a, Mapping): + return all(self._is_equal(a.get(k, None), b.get(k, None)) for k in b.keys()) + + return a == b + + @pytest.mark.parametrize("clip_val", [0, 10]) @RunIf(min_cuda_gpus=1, fairscale=True) @mock.patch("fairscale.optim.oss.OSS.clip_grad_norm") @@ -324,3 +357,33 @@ def test_ddp_sharded_strategy_checkpoint_multi_gpu_fairscale_optimizer(tmpdir, s # Assert model parameters are identical after loading for trained_param, loaded_param in zip(model.parameters(), saved_model.parameters()): assert torch.equal(trained_param.to("cpu"), loaded_param) + + +@RunIf(min_cuda_gpus=2, fairscale=True) +def test_ddp_sharded_strategy_fit_ckpt_path_downsize_gpus(tmpdir): + model = ModelWithAdamOptimizer() + trainer = Trainer( + strategy="ddp_sharded_spawn", + max_epochs=1, + limit_train_batches=1, + limit_val_batches=0, + accelerator="gpu", + devices=2, + ) + trainer.fit(model) + + checkpoint_path = trainer.checkpoint_callback.best_model_path + ckpt = torch.load(checkpoint_path) + old_model_state_dict = deepcopy(ckpt["state_dict"]) + old_optimizer_states = deepcopy(ckpt["optimizer_states"]) + + model = CheckModelRestore(old_model_state_dict, old_optimizer_states) + trainer = Trainer( + strategy="ddp_sharded_spawn", + max_epochs=2, + limit_train_batches=1, + limit_val_batches=0, + accelerator="gpu", + devices=1, + ) + trainer.fit(model, ckpt_path=checkpoint_path)