diff --git a/CHANGELOG.md b/CHANGELOG.md
index 23e7ced49d..92968eeb94 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -144,7 +144,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Changed warnings and recommendations for dataloaders in `ddp_spawn` ([#6762](https://github.com/PyTorchLightning/pytorch-lightning/pull/6762/))


-- `pl.seed_everyting` will now also set the seed on the `DistributedSampler` ([#7024](https://github.com/PyTorchLightning/pytorch-lightning/pull/7024))
+- `pl.seed_everything` will now also set the seed on the `DistributedSampler` ([#7024](https://github.com/PyTorchLightning/pytorch-lightning/pull/7024))
+
+
+- Changed default setting for communication of multi-node training using `DDPShardedPlugin` ([#6937](https://github.com/PyTorchLightning/pytorch-lightning/pull/6937))


 ### Deprecated
diff --git a/pytorch_lightning/plugins/training_type/sharded.py b/pytorch_lightning/plugins/training_type/sharded.py
index 62404b66f8..fbcdb405a5 100644
--- a/pytorch_lightning/plugins/training_type/sharded.py
+++ b/pytorch_lightning/plugins/training_type/sharded.py
@@ -20,7 +20,7 @@ from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.core.optimizer import is_lightning_optimizer
 from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
 from pytorch_lightning.trainer.states import TrainerState
-from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, rank_zero_only
+from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE, rank_zero_only

 if _FAIRSCALE_AVAILABLE:
     from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel
@@ -32,10 +32,15 @@ if _FAIRSCALE_AVAILABLE:
 class DDPShardedPlugin(DDPPlugin):
     """ Optimizer and gradient sharded training provided by FairScale. """

+    _REDUCE_BUFFER_SIZE_DEFAULT = 2**23  # 8M
+
     def configure_ddp(self):
         self._wrap_optimizers()
         self._model = ShardedDataParallel(
-            LightningShardedDataParallel(self.model), sharded_optimizer=self.lightning_module.trainer.optimizers
+            LightningShardedDataParallel(self.model),
+            sharded_optimizer=self.lightning_module.trainer.optimizers,
+            # For multi-node training, enabling bucketing will improve performance.
+            reduce_buffer_size=self._REDUCE_BUFFER_SIZE_DEFAULT if self.num_nodes > 1 else 0,
         )
         setattr(self._model, "require_backward_grad_sync", False)

@@ -47,6 +52,12 @@ class DDPShardedPlugin(DDPPlugin):
             if not isinstance(optimizer, OSS):
                 optim_class = type(optimizer)
                 zero_optimizer = OSS(params=optimizer.param_groups, optim=optim_class, **optimizer.defaults)
+                if _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE:
+                    is_fp16 = self.lightning_module.trainer.precision == 16
+                    # For multi-node training, compressing the model shards in fp16 before broadcasting
+                    # improves performance. When using PyTorch AMP, it will not degrade
+                    # the model performance.
+                    zero_optimizer.broadcast_fp16 = is_fp16 and self.num_nodes > 1
                 optimizers[x] = zero_optimizer
                 del optimizer
         trainer = self.lightning_module.trainer
@@ -58,7 +69,7 @@ class DDPShardedPlugin(DDPPlugin):

         return self._reinit_optimizers_with_oss()

-    def optimizer_state(self, optimizer: 'OSS') -> Optional[dict]:
+    def optimizer_state(self, optimizer: "OSS") -> Optional[dict]:
         if is_lightning_optimizer(optimizer):
             optimizer = optimizer._optimizer
         optimizer.consolidate_state_dict()
diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py
index 3c1108b535..979e6d95dd 100644
--- a/pytorch_lightning/utilities/__init__.py
+++ b/pytorch_lightning/utilities/__init__.py
@@ -36,6 +36,7 @@ from pytorch_lightning.utilities.imports import (  # noqa: F401
     _DEEPSPEED_AVAILABLE,
     _FAIRSCALE_AVAILABLE,
     _FAIRSCALE_PIPE_AVAILABLE,
+    _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE,
     _GROUP_AVAILABLE,
     _HOROVOD_AVAILABLE,
     _HYDRA_AVAILABLE,
diff --git a/pytorch_lightning/utilities/imports.py b/pytorch_lightning/utilities/imports.py
index 791cef7ff2..7b36c47285 100644
--- a/pytorch_lightning/utilities/imports.py
+++ b/pytorch_lightning/utilities/imports.py
@@ -75,6 +75,7 @@ _BOLTS_AVAILABLE = _module_available('pl_bolts')
 _DEEPSPEED_AVAILABLE = not _IS_WINDOWS and _module_available('deepspeed')
 _FAIRSCALE_AVAILABLE = _TORCH_GREATER_EQUAL_1_6 and not _IS_WINDOWS and _module_available('fairscale.nn')
 _FAIRSCALE_PIPE_AVAILABLE = _FAIRSCALE_AVAILABLE and _compare_version("fairscale", operator.le, "0.1.3")
+_FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE = _FAIRSCALE_AVAILABLE and _compare_version("fairscale", operator.ge, "0.3.3")
 _GROUP_AVAILABLE = not _IS_WINDOWS and _module_available('torch.distributed.group')
 _HOROVOD_AVAILABLE = _module_available("horovod.torch")
 _HYDRA_AVAILABLE = _module_available("hydra")
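For context, a minimal sketch (not part of this diff) of a run that exercises the new defaults, assuming the `Trainer` flags available in this Lightning release; `MyModel`, `MyDataModule`, and the GPU/node counts are placeholders:

```python
# Sketch only: a multi-node sharded run that hits the new code paths above.
# MyModel / MyDataModule and the resource flags are hypothetical user code.
import pytorch_lightning as pl

trainer = pl.Trainer(
    gpus=8,
    num_nodes=2,            # num_nodes > 1 switches reduce_buffer_size to the 8M default
    accelerator="ddp",
    plugins="ddp_sharded",  # selects DDPShardedPlugin
    precision=16,           # with fairscale >= 0.3.3, OSS broadcasts shards in fp16
)
# trainer.fit(MyModel(), datamodule=MyDataModule())
```

With a single node (or fairscale < 0.3.3) the plugin keeps the previous behaviour: `reduce_buffer_size` stays 0 and `broadcast_fp16` is not enabled.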