diff --git a/CHANGELOG.md b/CHANGELOG.md
index 602f1724a1..b265dc0999 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -402,6 +402,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Removed `Timer.on_save_checkpoint` and `Timer.on_load_checkpoint` in favor of `Timer.state_dict` and `Timer.load_state_dict` ([#11887](https://github.com/PyTorchLightning/pytorch-lightning/pull/11887))
 
+- Replaced `PostLocalSGDOptimizer` with a dedicated model averaging component ([#12378](https://github.com/PyTorchLightning/pytorch-lightning/pull/12378))
+
+
 ### Deprecated
 
 - Deprecated `training_type_plugin` property in favor of `strategy` in `Trainer` and updated the references ([#11141](https://github.com/PyTorchLightning/pytorch-lightning/pull/11141))
diff --git a/docs/source/advanced/model_parallel.rst b/docs/source/advanced/model_parallel.rst
index 0f9192b239..3b45bd376b 100644
--- a/docs/source/advanced/model_parallel.rst
+++ b/docs/source/advanced/model_parallel.rst
@@ -749,10 +749,7 @@ Enable `FP16 Compress Hook for multi-node throughput improvement
     Post-localSGD support requires PyTorch>=1.10.0
+
+.. code-block:: python
+
+    from pytorch_lightning import Trainer
+    from pytorch_lightning.strategies import DDPStrategy
+    from torch.distributed.algorithms.ddp_comm_hooks import post_localSGD_hook as post_localSGD
+
+    model = MyModel()
+    trainer = Trainer(
+        gpus=4,
+        strategy=DDPStrategy(
+            ddp_comm_state=post_localSGD.PostLocalSGDState(
+                process_group=None,
+                subgroup=None,
+                start_localSGD_iter=8,
+            ),
+            ddp_comm_hook=post_localSGD.post_localSGD_hook,
+            model_averaging_period=4,
+        ),
+    )
+    trainer.fit(model)
+
 DDP Static Graph
 """"""""""""""""
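For orientation, the documented Lightning example above maps onto the plain-PyTorch post-localSGD recipe below: the DDP communication hook covers the gradient-averaging phase, and a PeriodicModelAverager periodically synchronizes parameters once local SGD has kicked in. This is only an illustrative sketch, not part of the patch; the model, data, learning rate and step counts are made up, and it assumes the default process group has already been initialized by the launcher.

import torch
import torch.distributed as dist
from torch.distributed.algorithms.ddp_comm_hooks import post_localSGD_hook as post_localSGD
from torch.distributed.algorithms.model_averaging.averagers import PeriodicModelAverager
from torch.nn.parallel import DistributedDataParallel

# assumption: dist.init_process_group(...) has already been called (e.g. by torchrun)
assert dist.is_initialized()
model = DistributedDataParallel(torch.nn.Linear(32, 4))

# run vanilla all-reduce for the first 8 steps, then switch to local gradient averaging
state = post_localSGD.PostLocalSGDState(process_group=None, subgroup=None, start_localSGD_iter=8)
model.register_comm_hook(state, post_localSGD.post_localSGD_hook)

# once past the warmup, average the parameters across ranks every 4 optimizer steps
averager = PeriodicModelAverager(period=4, warmup_steps=8)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

for step in range(32):
    inputs, targets = torch.randn(16, 32), torch.randn(16, 4)
    loss = torch.nn.functional.mse_loss(model(inputs), targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    averager.average_parameters(model.parameters())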
diff --git a/pytorch_lightning/strategies/ddp.py b/pytorch_lightning/strategies/ddp.py
index ad67467d02..146dd6a31d 100644
--- a/pytorch_lightning/strategies/ddp.py
+++ b/pytorch_lightning/strategies/ddp.py
@@ -18,12 +18,13 @@ import signal
 import tempfile
 import time
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union
 
 import torch
 import torch.distributed
 from torch.nn import Module
 from torch.nn.parallel.distributed import DistributedDataParallel
+from torch.optim.optimizer import Optimizer
 
 import pytorch_lightning as pl
 from pytorch_lightning.core.optimizer import LightningOptimizer
@@ -59,6 +60,8 @@ if _FAIRSCALE_AVAILABLE:
     from fairscale.optim import OSS
 if _TORCH_GREATER_EQUAL_1_8:
     from pytorch_lightning.utilities.distributed import register_ddp_comm_hook
+if _TORCH_GREATER_EQUAL_1_10:
+    from torch.distributed.algorithms.model_averaging.averagers import ModelAverager
 
 log = logging.getLogger(__name__)
 
@@ -97,6 +100,7 @@ class DDPStrategy(ParallelStrategy):
         self._ddp_comm_hook = ddp_comm_hook
         self._ddp_comm_wrapper = ddp_comm_wrapper
         self._model_averaging_period = model_averaging_period
+        self._model_averager: Optional[ModelAverager] = None
         self._pids: Optional[List[int]] = None
         self._sync_dir: Optional[str] = None
         self._rank_0_will_call_children_scripts: bool = False
@@ -223,23 +227,18 @@ class DDPStrategy(ParallelStrategy):
                 import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD
 
                 if isinstance(self._ddp_comm_state, post_localSGD.PostLocalSGDState):
-                    self._reinit_optimizers_with_post_localSGD(self._ddp_comm_state.start_localSGD_iter)
+                    self._enable_model_averaging()
 
-    def _reinit_optimizers_with_post_localSGD(self, warmup_steps: int):
+    def _enable_model_averaging(self) -> None:
+        # Only called when PyTorch version >= 1.10
         log.detail(f"{self.__class__.__name__}: reinitializing optimizers with post localSGD")
-        optimizers = self.optimizers
         if self._model_averaging_period is None:
             raise ValueError(
                 "Post-localSGD algorithm is used, but model averaging period is not provided to DDP strategy."
             )
-        if _TORCH_GREATER_EQUAL_1_10:
-            if not _IS_WINDOWS:
-                from torch.distributed.optim import DistributedOptimizer
-            import torch.distributed.algorithms.model_averaging.averagers as averagers
-            from torch.distributed.optim import PostLocalSGDOptimizer, ZeroRedundancyOptimizer
+        from torch.distributed.optim import DistributedOptimizer, PostLocalSGDOptimizer, ZeroRedundancyOptimizer
 
-        averager = averagers.PeriodicModelAverager(period=self._model_averaging_period, warmup_steps=warmup_steps)
-        for x, optimizer in enumerate(optimizers):
+        for optimizer in self.optimizers:
             if isinstance(optimizer, LightningOptimizer):
                 optimizer = optimizer._optimizer
 
@@ -248,24 +247,43 @@ class DDPStrategy(ParallelStrategy):
                 is_distributed_optimizer
                 or isinstance(optimizer, ZeroRedundancyOptimizer)
                 or (_FAIRSCALE_AVAILABLE and isinstance(optimizer, OSS))
+                or isinstance(optimizer, PostLocalSGDOptimizer)
             ):
                 raise ValueError(
-                    f"Cannot wrap a distributed optimizer of type {optimizer.__name__} by PostLocalSGDOptimizer."
+                    f"Currently model averaging cannot work with a distributed optimizer of type "
+                    f"{optimizer.__class__.__name__}."
                 )
 
-            if isinstance(optimizer, PostLocalSGDOptimizer):
-                continue
+        self._model_averager = torch.distributed.algorithms.model_averaging.averagers.PeriodicModelAverager(
+            period=self._model_averaging_period, warmup_steps=self._ddp_comm_state.start_localSGD_iter
+        )
 
-            optim_class = type(optimizer)
-            post_localSGD_optimizer = PostLocalSGDOptimizer(
-                params=optimizer.param_groups,
-                optimizer_class=optim_class,
-                averager=averager,
-                **optimizer.defaults,
-            )
-            optimizers[x] = post_localSGD_optimizer
-            del optimizer
-        self.optimizers = optimizers
+    def optimizer_step(
+        self,
+        optimizer: Optimizer,
+        opt_idx: int,
+        closure: Callable[[], Any],
+        model: Optional[Union["pl.LightningModule", Module]] = None,
+        **kwargs: Any,
+    ) -> Any:
+        """Performs the actual optimizer step.
+
+        Args:
+            optimizer: the optimizer performing the step
+            opt_idx: index of the current optimizer
+            closure: closure calculating the loss value
+            model: reference to the model, optionally defining optimizer step related hooks
+            **kwargs: Any extra arguments to ``optimizer.step``
+        """
+        optimizer_output = super().optimizer_step(optimizer, opt_idx, closure, model, **kwargs)
+
+        if not _TORCH_GREATER_EQUAL_1_10 or self._model_averager is None:
+            return optimizer_output
+
+        params = [param for group in optimizer.param_groups for param in group["params"] if param.grad is not None]
+        self._model_averager.average_parameters(iter(params))
+
+        return optimizer_output
 
     def configure_ddp(self) -> None:
         log.detail(f"{self.__class__.__name__}: configuring DistributedDataParallel")
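The net effect of the ddp.py change: instead of rebuilding every optimizer as a PostLocalSGDOptimizer (which performs model averaging inside its own step()), the strategy now keeps the user's optimizer untouched and calls a PeriodicModelAverager itself right after each optimizer step. A rough before/after sketch of that behaviour, using a plain SGD optimizer for illustration and assuming a process group is already initialized (the averaging all-reduce needs one):

import torch
from torch.distributed.algorithms.model_averaging.averagers import PeriodicModelAverager
from torch.distributed.optim import PostLocalSGDOptimizer

model = torch.nn.Linear(32, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
averager = PeriodicModelAverager(period=4, warmup_steps=8)

model(torch.randn(16, 32)).sum().backward()  # produce some gradients

# Before this PR: the optimizer was wrapped (PyTorch 1.10 constructor signature),
# and the wrapper both stepped the inner optimizer and ran the model averaging.
wrapped = PostLocalSGDOptimizer(
    params=optimizer.param_groups,
    optimizer_class=type(optimizer),
    averager=averager,
    **optimizer.defaults,
)
wrapped.step()

# After this PR: the original optimizer stays as-is; DDPStrategy.optimizer_step
# performs the regular step and then drives the averager explicitly.
optimizer.step()
params = [p for group in optimizer.param_groups for p in group["params"] if p.grad is not None]
averager.average_parameters(iter(params))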
diff --git a/pytorch_lightning/strategies/strategy.py b/pytorch_lightning/strategies/strategy.py
index 8aec73ffab..db33c4ec72 100644
--- a/pytorch_lightning/strategies/strategy.py
+++ b/pytorch_lightning/strategies/strategy.py
@@ -181,7 +181,7 @@ class Strategy(ABC):
         model: Optional[Union["pl.LightningModule", Module]] = None,
         **kwargs: Any,
     ) -> Any:
-        """performs the actual optimizer step.
+        """Performs the actual optimizer step.
 
         Args:
             optimizer: the optimizer performing the step
diff --git a/tests/strategies/test_ddp_strategy_with_comm_hook.py b/tests/strategies/test_ddp_strategy_with_comm_hook.py
index ab8beaa934..0003f72b9a 100644
--- a/tests/strategies/test_ddp_strategy_with_comm_hook.py
+++ b/tests/strategies/test_ddp_strategy_with_comm_hook.py
@@ -11,6 +11,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from unittest import mock
+
+import pytest
 import torch
 
 from pytorch_lightning import Trainer
@@ -136,3 +139,80 @@ def test_ddp_post_local_sgd_comm_hook(tmpdir):
     expected_comm_hook = post_localSGD.post_localSGD_hook.__qualname__
     assert trainer_comm_hook == expected_comm_hook
     assert trainer.state.finished, f"Training failed with {trainer.state}"
+
+
+@RunIf(skip_windows=True, min_torch="1.10.0", min_gpus=2, standalone=True)
+@mock.patch("torch.distributed.algorithms.model_averaging.averagers.PeriodicModelAverager.average_parameters")
+def test_post_local_sgd_model_averaging(average_parameters_mock, tmpdir):
+    """Test that when using DDP with post-localSGD, model averaging is called."""
+    model = BoringModel()
+
+    # test regular ddp does not call model averaging
+    trainer = Trainer(
+        fast_dev_run=True,
+        gpus=2,
+        strategy="ddp",
+        default_root_dir=tmpdir,
+        sync_batchnorm=True,
+    )
+
+    trainer.fit(model)
+    average_parameters_mock.assert_not_called()
+
+    # test ddp with post-localSGD does call model averaging
+    ddp_strategy = DDPStrategy(
+        ddp_comm_state=post_localSGD.PostLocalSGDState(
+            process_group=None,
+            subgroup=None,
+            start_localSGD_iter=8,
+        ),
+        ddp_comm_hook=post_localSGD.post_localSGD_hook,
+        model_averaging_period=4,
+    )
+
+    trainer = Trainer(
+        fast_dev_run=True,
+        gpus=2,
+        strategy=ddp_strategy,
+        default_root_dir=tmpdir,
+        sync_batchnorm=True,
+    )
+
+    trainer.fit(model)
+    average_parameters_mock.assert_called()
+
+
+@RunIf(skip_windows=True, min_torch="1.10.0", min_gpus=2, standalone=True)
+@mock.patch("torch.distributed.algorithms.model_averaging.averagers.PeriodicModelAverager.average_parameters")
+def test_post_local_sgd_model_averaging_value_error(average_parameters_mock, tmpdir):
+    """Test that when using DDP with post-localSGD, a ValueError is raised when the optimizer is a
+    ZeroRedundancyOptimizer."""
+    from torch.distributed.optim import ZeroRedundancyOptimizer
+
+    class OptimizerModel(BoringModel):
+        def configure_optimizers(self):
+            return ZeroRedundancyOptimizer(params=self.parameters(), optimizer_class=torch.optim.Adam, lr=0.01)
+
+    model = OptimizerModel()
+    strategy = DDPStrategy(
+        ddp_comm_state=post_localSGD.PostLocalSGDState(
+            process_group=None,
+            subgroup=None,
+            start_localSGD_iter=8,
+        ),
+        ddp_comm_hook=post_localSGD.post_localSGD_hook,
+        model_averaging_period=4,
+    )
+
+    trainer = Trainer(
+        fast_dev_run=True,
+        gpus=2,
+        strategy=strategy,
+        default_root_dir=tmpdir,
+        sync_batchnorm=True,
+    )
+
+    with pytest.raises(ValueError, match="Currently model averaging cannot work with a distributed optimizer"):
+        trainer.fit(model)
+
+    average_parameters_mock.assert_not_called()
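Finally, a hypothetical single-process sanity check (not part of this PR) of the schedule that both the new strategy code and the tests above rely on: DDPStrategy forwards start_localSGD_iter as the averager's warmup_steps and model_averaging_period as its period, so averaging is skipped during warmup and then runs periodically. With world_size=1 the underlying all-reduce is a no-op, so only the step gating is exercised; the address, port and tensor shape are arbitrary.

import os
import torch
import torch.distributed as dist
from torch.distributed.algorithms.model_averaging.averagers import PeriodicModelAverager

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)

param = torch.nn.Parameter(torch.ones(4))
averager = PeriodicModelAverager(period=4, warmup_steps=8)

for step in range(16):
    # no-op during the 8 warmup steps, afterwards averages every 4th call
    averager.average_parameters([param])

dist.destroy_process_group()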