Set smarter default for DDP sharded for performance optimization (#6937)
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
commit 52a5cee0a7
parent dd5ec75e48
@@ -144,7 +144,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Changed warnings and recommendations for dataloaders in `ddp_spawn` ([#6762](https://github.com/PyTorchLightning/pytorch-lightning/pull/6762/))

-- `pl.seed_everyting` will now also set the seed on the `DistributedSampler` ([#7024](https://github.com/PyTorchLightning/pytorch-lightning/pull/7024))
+- `pl.seed_everything` will now also set the seed on the `DistributedSampler` ([#7024](https://github.com/PyTorchLightning/pytorch-lightning/pull/7024))
+
+- Changed default setting for communication of multi-node training using `DDPShardedPlugin` ([#6937](https://github.com/PyTorchLightning/pytorch-lightning/pull/6937))
+

 ### Deprecated
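Note on the `DDPShardedPlugin` entry above: the new communication default only takes effect for multi-node runs (`num_nodes > 1`). A minimal sketch of a run that would pick it up, assuming the Lightning 1.3-era `Trainer(plugins="ddp_sharded")` arguments; the `ToyModel` module and the dataset are illustrative stand-ins, not part of this commit:

import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, TensorDataset


class ToyModel(pl.LightningModule):
    # hypothetical minimal module, only here to make the sketch self-contained
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.cross_entropy(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)


train_loader = DataLoader(
    TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,))), batch_size=8
)

trainer = pl.Trainer(
    gpus=8,
    num_nodes=2,            # multi-node: gradient reduction now defaults to the 8M bucket
    accelerator="ddp",
    plugins="ddp_sharded",  # selects DDPShardedPlugin
    precision=16,           # with fairscale >= 0.3.3 the OSS state broadcast is also compressed to fp16
    max_epochs=1,
)
trainer.fit(ToyModel(), train_loader)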
@@ -20,7 +20,7 @@ from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.core.optimizer import is_lightning_optimizer
 from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
 from pytorch_lightning.trainer.states import TrainerState
-from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, rank_zero_only
+from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE, rank_zero_only

 if _FAIRSCALE_AVAILABLE:
     from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel
@@ -32,10 +32,15 @@ if _FAIRSCALE_AVAILABLE:
 class DDPShardedPlugin(DDPPlugin):
     """ Optimizer and gradient sharded training provided by FairScale. """

+    _REDUCE_BUFFER_SIZE_DEFAULT = 2**23  # 8M
+
     def configure_ddp(self):
         self._wrap_optimizers()
         self._model = ShardedDataParallel(
-            LightningShardedDataParallel(self.model), sharded_optimizer=self.lightning_module.trainer.optimizers
+            LightningShardedDataParallel(self.model),
+            sharded_optimizer=self.lightning_module.trainer.optimizers,
+            # For multi-node training, enabling bucketing will improve performance.
+            reduce_buffer_size=self._REDUCE_BUFFER_SIZE_DEFAULT if self.num_nodes > 1 else 0,
         )
         setattr(self._model, "require_backward_grad_sync", False)
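Because the bucket size is exposed as a class attribute rather than hard-coded inside `configure_ddp`, it can be tuned without re-implementing the method. A minimal sketch, assuming a plugin instance can be passed to `Trainer(plugins=...)`; the subclass name and the 32M value are illustrative only:

from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DDPShardedPlugin


class BigBucketShardedPlugin(DDPShardedPlugin):
    # illustrative override: fewer, larger reduction buckets (32M instead of the 8M default)
    _REDUCE_BUFFER_SIZE_DEFAULT = 2**25


trainer = Trainer(gpus=8, num_nodes=2, accelerator="ddp", plugins=[BigBucketShardedPlugin()])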
@@ -47,6 +52,12 @@ class DDPShardedPlugin(DDPPlugin):
             if not isinstance(optimizer, OSS):
                 optim_class = type(optimizer)
                 zero_optimizer = OSS(params=optimizer.param_groups, optim=optim_class, **optimizer.defaults)
+                if _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE:
+                    is_fp16 = self.lightning_module.trainer.precision == 16
+                    # For multi-node training, compressing the model shards in fp16 before broadcasting
+                    # improves performance. When using PyTorch AMP, it will not degrade
+                    # the model performance.
+                    zero_optimizer.broadcast_fp16 = is_fp16 and self.num_nodes > 1
                 optimizers[x] = zero_optimizer
                 del optimizer
         trainer = self.lightning_module.trainer
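Outside of Lightning, the same rewrap looks roughly like the sketch below. It mirrors what `_reinit_optimizers_with_oss` does and assumes fairscale >= 0.3.3 (for `broadcast_fp16`) plus an initialized process group; the single-process gloo group and the toy model are only there to make the sketch runnable:

import torch
import torch.distributed as dist
from fairscale.optim.oss import OSS

# OSS needs a process group; a single-process gloo group is enough for this sketch.
if not dist.is_initialized():
    dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1)

model = torch.nn.Linear(32, 2)
base = torch.optim.Adam(model.parameters(), lr=1e-3)

# Re-create the optimizer as a sharded (ZeRO-style) OSS optimizer, as the plugin does above.
zero_optimizer = OSS(params=base.param_groups, optim=type(base), **base.defaults)

# With fairscale >= 0.3.3, the shard broadcast can be compressed to fp16; the plugin
# only turns this on for 16-bit, multi-node runs.
zero_optimizer.broadcast_fp16 = True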
@@ -58,7 +69,7 @@ class DDPShardedPlugin(DDPPlugin):
             return
         self._reinit_optimizers_with_oss()

-    def optimizer_state(self, optimizer: 'OSS') -> Optional[dict]:
+    def optimizer_state(self, optimizer: "OSS") -> Optional[dict]:
         if is_lightning_optimizer(optimizer):
             optimizer = optimizer._optimizer
         optimizer.consolidate_state_dict()
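For context on the `optimizer_state` hook touched here: OSS keeps only each rank's shard of the optimizer state, so `consolidate_state_dict()` must gather the shards before a full `state_dict()` can be checkpointed. A rough sketch, reusing `zero_optimizer` and the process group from the previous example; the rank-0 save and the file name are illustrative:

# Gather all shards of the optimizer state onto the default recipient rank (0), then
# save the full state dict from that rank only.
zero_optimizer.consolidate_state_dict()
if dist.get_rank() == 0:
    torch.save(zero_optimizer.state_dict(), "oss_optimizer.pt")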
@@ -36,6 +36,7 @@ from pytorch_lightning.utilities.imports import (  # noqa: F401
     _DEEPSPEED_AVAILABLE,
     _FAIRSCALE_AVAILABLE,
     _FAIRSCALE_PIPE_AVAILABLE,
+    _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE,
     _GROUP_AVAILABLE,
     _HOROVOD_AVAILABLE,
     _HYDRA_AVAILABLE,
@@ -75,6 +75,7 @@ _BOLTS_AVAILABLE = _module_available('pl_bolts')
 _DEEPSPEED_AVAILABLE = not _IS_WINDOWS and _module_available('deepspeed')
 _FAIRSCALE_AVAILABLE = _TORCH_GREATER_EQUAL_1_6 and not _IS_WINDOWS and _module_available('fairscale.nn')
 _FAIRSCALE_PIPE_AVAILABLE = _FAIRSCALE_AVAILABLE and _compare_version("fairscale", operator.le, "0.1.3")
+_FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE = _FAIRSCALE_AVAILABLE and _compare_version("fairscale", operator.ge, "0.3.3")
 _GROUP_AVAILABLE = not _IS_WINDOWS and _module_available('torch.distributed.group')
 _HOROVOD_AVAILABLE = _module_available("horovod.torch")
 _HYDRA_AVAILABLE = _module_available("hydra")
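The new `_FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE` flag simply gates the feature on the installed fairscale version. A self-contained approximation of that check, using `importlib`/`packaging` stand-ins rather than Lightning's private `_module_available`/`_compare_version` helpers (the function names below are not Lightning's):

import operator
from importlib.metadata import PackageNotFoundError, version as installed_version
from importlib.util import find_spec

from packaging.version import Version


def module_available(dotted_name: str) -> bool:
    # stand-in for Lightning's `_module_available`: is the top-level package importable?
    return find_spec(dotted_name.split(".")[0]) is not None


def compare_version(package: str, op, required: str) -> bool:
    # stand-in for Lightning's `_compare_version`: apply `op` to (installed, required) versions
    try:
        return op(Version(installed_version(package)), Version(required))
    except PackageNotFoundError:
        return False


FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE = (
    module_available("fairscale.nn") and compare_version("fairscale", operator.ge, "0.3.3")
)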