From 366fb39d2e021ef6fa663b972f2f4e6d92b62775 Mon Sep 17 00:00:00 2001
From: Yi Wang
Date: Thu, 26 Aug 2021 00:24:49 -0700
Subject: [PATCH] Support post-localSGD in Lightning DDP plugin (#8967)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: ananthsub
Co-authored-by: Adrian Wälchli
---
 .../plugins/training_type/ddp.py              | 59 +++++++++++++++++++
 pytorch_lightning/utilities/distributed.py    | 14 +++++
 .../plugins/test_ddp_plugin_with_comm_hook.py | 33 ++++++++++-
 3 files changed, 105 insertions(+), 1 deletion(-)

diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py
index 787353be30..aeb43fcdeb 100644
--- a/pytorch_lightning/plugins/training_type/ddp.py
+++ b/pytorch_lightning/plugins/training_type/ddp.py
@@ -29,17 +29,21 @@ import torch
 import torch.distributed
 from torch.nn.parallel.distributed import DistributedDataParallel
 
+from pytorch_lightning.core.optimizer import LightningOptimizer
 from pytorch_lightning.distributed import LightningDistributed
 from pytorch_lightning.overrides import LightningDistributedModule
 from pytorch_lightning.overrides.distributed import prepare_for_backward
 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
 from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
+from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities import (
+    _FAIRSCALE_AVAILABLE,
     _HYDRA_AVAILABLE,
     _TORCH_GREATER_EQUAL_1_7,
     _TORCH_GREATER_EQUAL_1_8,
     _TORCH_GREATER_EQUAL_1_9,
+    _TORCH_GREATER_EQUAL_1_10,
     rank_zero_deprecation,
     rank_zero_warn,
 )
@@ -53,11 +57,19 @@ from pytorch_lightning.utilities.distributed import (
 from pytorch_lightning.utilities.exceptions import DeadlockDetectedException, MisconfigurationException
 from pytorch_lightning.utilities.seed import reset_seed
 
+if _TORCH_GREATER_EQUAL_1_10:
+    from torch.distributed.optim import DistributedOptimizer, PostLocalSGDOptimizer, ZeroRedundancyOptimizer
+
+if _FAIRSCALE_AVAILABLE:
+    from fairscale.optim import OSS
 if _HYDRA_AVAILABLE:
     from hydra.core.hydra_config import HydraConfig
     from hydra.utils import get_original_cwd, to_absolute_path
 if _TORCH_GREATER_EQUAL_1_8:
     from pytorch_lightning.utilities.distributed import register_ddp_comm_hook
+if _TORCH_GREATER_EQUAL_1_10:
+    import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD
+    import torch.distributed.algorithms.model_averaging.averagers as averagers
 
 log = logging.getLogger(__name__)
 
@@ -83,6 +95,7 @@ class DDPPlugin(ParallelPlugin):
         ddp_comm_state: Optional[object] = None,
         ddp_comm_hook: Optional[callable] = None,
         ddp_comm_wrapper: Optional[callable] = None,
+        model_averaging_period: Optional[int] = None,
         **kwargs: Union[Any, Dict[str, Any]],
     ) -> None:
         super().__init__(
@@ -110,6 +123,7 @@ class DDPPlugin(ParallelPlugin):
         self._ddp_comm_state = ddp_comm_state
         self._ddp_comm_hook = ddp_comm_hook
         self._ddp_comm_wrapper = ddp_comm_wrapper
+        self._model_averaging_period = model_averaging_period
         self._pids: Optional[List[int]] = None
         self._sync_dir: Optional[str] = None
         self.set_world_ranks()
@@ -302,6 +316,51 @@ class DDPPlugin(ParallelPlugin):
                 ddp_comm_wrapper=self._ddp_comm_wrapper,
             )
 
+        # Post-localSGD is only available after 1.9,
+        # and the `torch.distributed.optim` package currently is not available on Windows.
+        if (
+            _TORCH_GREATER_EQUAL_1_10
+            and isinstance(self._ddp_comm_state, post_localSGD.PostLocalSGDState)
+            and self.lightning_module.trainer.state.fn == TrainerFn.FITTING
+        ):
+            self._reinit_optimizers_with_post_localSGD(self._ddp_comm_state.start_localSGD_iter)
+
+    def _reinit_optimizers_with_post_localSGD(self, warmup_steps: int):
+        optimizers = self.lightning_module.trainer.optimizers
+        if self._model_averaging_period is None:
+            raise ValueError(
+                "Post-localSGD algorithm is used, but model averaging period is not provided to the DDP plugin."
+            )
+        averager = averagers.PeriodicModelAverager(period=self._model_averaging_period, warmup_steps=warmup_steps)
+        for x, optimizer in enumerate(optimizers):
+            if isinstance(optimizer, LightningOptimizer):
+                optimizer = optimizer._optimizer
+
+            if (
+                isinstance(optimizer, DistributedOptimizer)
+                or isinstance(optimizer, ZeroRedundancyOptimizer)
+                or (_FAIRSCALE_AVAILABLE and isinstance(optimizer, OSS))
+            ):
+                raise ValueError(
+                    f"Cannot wrap a distributed optimizer of type {optimizer.__class__.__name__} by PostLocalSGDOptimizer."
+                )
+
+            if isinstance(optimizer, PostLocalSGDOptimizer):
+                continue
+
+            optim_class = type(optimizer)
+            post_localSGD_optimizer = PostLocalSGDOptimizer(
+                params=optimizer.param_groups,
+                optimizer_class=optim_class,
+                averager=averager,
+                **optimizer.defaults,
+            )
+            optimizers[x] = post_localSGD_optimizer
+            del optimizer
+        trainer = self.lightning_module.trainer
+        trainer.optimizers = optimizers
+        trainer.convert_to_lightning_optimizers()
+
     def configure_ddp(self):
         self.pre_configure_ddp()
         self._model = DistributedDataParallel(
diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py
index 68868a0ff7..4f254b6824 100644
--- a/pytorch_lightning/utilities/distributed.py
+++ b/pytorch_lightning/utilities/distributed.py
@@ -284,12 +284,14 @@ def register_ddp_comm_hook(
 
     .. warning ::
         DDP communication wrapper needs pytorch version at least 1.9.0
+        Post-localSGD hook needs pytorch version at least 1.9.0
 
     Example:
 
         from torch.distributed.algorithms.ddp_comm_hooks import (
             default_hooks as default,
             powerSGD_hook as powerSGD,
+            post_localSGD_hook as post_localSGD,
         )
 
         # fp16_compress_hook for compress gradients
@@ -309,6 +311,18 @@ def register_ddp_comm_hook(
             ddp_comm_hook=powerSGD.powerSGD_hook,
         )
 
+        # post_localSGD_hook
+        subgroup, _ = torch.distributed.new_subgroups()
+        register_ddp_comm_hook(
+            model=ddp_model,
+            ddp_comm_state=post_localSGD.PostLocalSGDState(
+                process_group=None,
+                subgroup=subgroup,
+                start_localSGD_iter=1_000,
+            ),
+            ddp_comm_hook=post_localSGD.post_localSGD_hook,
+        )
+
         # fp16_compress_wrapper combined with other communication hook
         register_ddp_comm_hook(
             model=ddp_model,
diff --git a/tests/plugins/test_ddp_plugin_with_comm_hook.py b/tests/plugins/test_ddp_plugin_with_comm_hook.py
index 49c4cd18ef..1e5968f1a0 100644
--- a/tests/plugins/test_ddp_plugin_with_comm_hook.py
+++ b/tests/plugins/test_ddp_plugin_with_comm_hook.py
@@ -15,13 +15,15 @@ import torch
 
 from pytorch_lightning import Trainer
 from pytorch_lightning.plugins import DDPPlugin, DDPSpawnPlugin
-from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_8
+from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_8, _TORCH_GREATER_EQUAL_1_10
 from tests.helpers import BoringModel
 from tests.helpers.runif import RunIf
 
 if torch.distributed.is_available() and _TORCH_GREATER_EQUAL_1_8:
     from torch.distributed.algorithms.ddp_comm_hooks import default_hooks as default
     from torch.distributed.algorithms.ddp_comm_hooks import powerSGD_hook as powerSGD
+if torch.distributed.is_available() and _TORCH_GREATER_EQUAL_1_10:
+    import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD
 
 
 @RunIf(skip_windows=True, min_torch="1.9.0", min_gpus=2, special=True)
@@ -108,3 +110,32 @@ def test_ddp_spawn_fp16_compress_comm_hook(tmpdir):
     )
     trainer.fit(model)
     assert trainer.state.finished, f"Training failed with {trainer.state}"
+
+
+@RunIf(skip_windows=True, min_torch="1.10.0", min_gpus=2, special=True)
+def test_ddp_post_local_sgd_comm_hook(tmpdir):
+    """Test for DDP post-localSGD hook."""
+    model = BoringModel()
+
+    training_type_plugin = DDPPlugin(
+        ddp_comm_state=post_localSGD.PostLocalSGDState(
+            process_group=None,
+            subgroup=None,
+            start_localSGD_iter=8,
+        ),
+        ddp_comm_hook=post_localSGD.post_localSGD_hook,
+        model_averaging_period=4,
+        sync_batchnorm=True,
+    )
+    trainer = Trainer(
+        fast_dev_run=True,
+        gpus=2,
+        plugins=[training_type_plugin],
+        default_root_dir=tmpdir,
+        sync_batchnorm=True,
+    )
+    trainer.fit(model)
+    trainer_comm_hook = trainer.accelerator.training_type_plugin._model.get_ddp_logging_data().comm_hook
+    expected_comm_hook = post_localSGD.post_localSGD_hook.__qualname__
+    assert trainer_comm_hook == expected_comm_hook
+    assert trainer.state.finished, f"Training failed with {trainer.state}"
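
Usage note (not part of the patch): a minimal sketch of how the new post-localSGD support could be enabled from user code, mirroring the test above. It assumes PyTorch 1.10+ with a 2-GPU DDP setup; ToyModel and the step counts (start_localSGD_iter=100, model_averaging_period=4) are illustrative placeholders rather than values taken from this change.

    import torch
    import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD
    from torch.utils.data import DataLoader, TensorDataset

    from pytorch_lightning import LightningModule, Trainer
    from pytorch_lightning.plugins import DDPPlugin


    class ToyModel(LightningModule):
        # Minimal stand-in model; any LightningModule works here.
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)

        def training_step(self, batch, batch_idx):
            # Use the summed output as a dummy scalar loss.
            return self.layer(batch[0]).sum()

        def configure_optimizers(self):
            # Plain SGD; DDPPlugin re-wraps it into a PostLocalSGDOptimizer at fit time.
            return torch.optim.SGD(self.parameters(), lr=0.1)

        def train_dataloader(self):
            return DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8)


    # Run standard DDP gradient averaging for the first 100 steps, then switch to
    # local SGD with model averaging every 4 optimizer steps (illustrative values).
    plugin = DDPPlugin(
        ddp_comm_state=post_localSGD.PostLocalSGDState(
            process_group=None,
            subgroup=None,
            start_localSGD_iter=100,
        ),
        ddp_comm_hook=post_localSGD.post_localSGD_hook,
        model_averaging_period=4,
    )

    trainer = Trainer(gpus=2, max_epochs=1, plugins=[plugin])
    trainer.fit(ToyModel())

As in the test, passing the configured DDPPlugin through Trainer(plugins=...) is enough: the plugin registers the post_localSGD communication hook and, once start_localSGD_iter is reached, relies on the periodic model averager configured by model_averaging_period.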