Set smarter default for DDP sharded for performance optimization (#6937)

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Authored by shuyingsunshine21 on 2021-04-26 15:31:34 -07:00, committed by GitHub
parent dd5ec75e48
commit 52a5cee0a7
4 changed files with 20 additions and 4 deletions

CHANGELOG.md

@@ -144,7 +144,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Changed warnings and recommendations for dataloaders in `ddp_spawn` ([#6762](https://github.com/PyTorchLightning/pytorch-lightning/pull/6762/))
-- `pl.seed_everyting` will now also set the seed on the `DistributedSampler` ([#7024](https://github.com/PyTorchLightning/pytorch-lightning/pull/7024))
+- `pl.seed_everything` will now also set the seed on the `DistributedSampler` ([#7024](https://github.com/PyTorchLightning/pytorch-lightning/pull/7024))
+- Changed default setting for communication of multi-node training using `DDPShardedPlugin` ([#6937](https://github.com/PyTorchLightning/pytorch-lightning/pull/6937))
 ### Deprecated

pytorch_lightning/plugins/training_type/sharded.py

@@ -20,7 +20,7 @@ from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.core.optimizer import is_lightning_optimizer
 from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
 from pytorch_lightning.trainer.states import TrainerState
-from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, rank_zero_only
+from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE, rank_zero_only
 
 if _FAIRSCALE_AVAILABLE:
     from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel
@@ -32,10 +32,15 @@ if _FAIRSCALE_AVAILABLE:
 class DDPShardedPlugin(DDPPlugin):
     """ Optimizer and gradient sharded training provided by FairScale. """
 
+    _REDUCE_BUFFER_SIZE_DEFAULT = 2**23  # 8M
+
     def configure_ddp(self):
         self._wrap_optimizers()
         self._model = ShardedDataParallel(
-            LightningShardedDataParallel(self.model), sharded_optimizer=self.lightning_module.trainer.optimizers
+            LightningShardedDataParallel(self.model),
+            sharded_optimizer=self.lightning_module.trainer.optimizers,
+            # For multi-node training, enabling bucketing will improve performance.
+            reduce_buffer_size=self._REDUCE_BUFFER_SIZE_DEFAULT if self.num_nodes > 1 else 0,
         )
         setattr(self._model, "require_backward_grad_sync", False)
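
For context, a minimal usage sketch (not part of the commit) of when the new default applies: FairScale's gradient bucketing is only enabled when the Trainer is configured for more than one node, while single-node runs keep `reduce_buffer_size=0`. The Trainer flags below are assumptions for a Lightning release of this era, where the sharded plugin is selected via `plugins="ddp_sharded"`.

```python
# A minimal sketch, assuming PyTorch Lightning ~1.3 with FairScale installed;
# `MyLightningModule` and the exact Trainer flags are placeholders, not part of this commit.
import pytorch_lightning as pl

trainer = pl.Trainer(
    gpus=8,
    num_nodes=2,            # num_nodes > 1 -> reduce_buffer_size defaults to 2**23 (8M)
    accelerator="ddp",
    plugins="ddp_sharded",  # selects DDPShardedPlugin
    precision=16,
)
# trainer.fit(MyLightningModule())  # a single-node run (num_nodes=1) keeps bucketing off
```

Nothing changes for existing single-node users; only multi-node runs pick up the 8M reduce buffer.
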
@@ -47,6 +52,12 @@ class DDPShardedPlugin(DDPPlugin):
             if not isinstance(optimizer, OSS):
                 optim_class = type(optimizer)
                 zero_optimizer = OSS(params=optimizer.param_groups, optim=optim_class, **optimizer.defaults)
+                if _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE:
+                    is_fp16 = self.lightning_module.trainer.precision == 16
+                    # For multi-node training, compressing the model shards in fp16 before broadcasting
+                    # improves performance. When using PyTorch AMP, it will not degrade
+                    # the model performance.
+                    zero_optimizer.broadcast_fp16 = is_fp16 and self.num_nodes > 1
                 optimizers[x] = zero_optimizer
                 del optimizer
         trainer = self.lightning_module.trainer
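
Alongside bucketing, the optimizer re-wrapping above gains fp16 shard broadcasting. A standalone sketch of that step, assuming fairscale >= 0.3.3; `wrap_in_oss` is a hypothetical helper for illustration, not Lightning API.

```python
# Illustrative standalone version of the re-wrapping step above.
import torch
from fairscale.optim import OSS


def wrap_in_oss(optimizer: torch.optim.Optimizer, precision: int, num_nodes: int) -> OSS:
    if isinstance(optimizer, OSS):
        return optimizer
    # Re-create the optimizer as a ZeRO-style sharded (OSS) optimizer.
    zero_optimizer = OSS(params=optimizer.param_groups, optim=type(optimizer), **optimizer.defaults)
    # Compress rank-local shards to fp16 before broadcasting them to other ranks;
    # only enabled for 16-bit precision across multiple nodes.
    zero_optimizer.broadcast_fp16 = precision == 16 and num_nodes > 1
    return zero_optimizer
```

Compression is only turned on for 16-bit, multi-node runs, where the broadcast volume between nodes dominates.
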
@@ -58,7 +69,7 @@ class DDPShardedPlugin(DDPPlugin):
             return
         self._reinit_optimizers_with_oss()
 
-    def optimizer_state(self, optimizer: 'OSS') -> Optional[dict]:
+    def optimizer_state(self, optimizer: "OSS") -> Optional[dict]:
         if is_lightning_optimizer(optimizer):
             optimizer = optimizer._optimizer
         optimizer.consolidate_state_dict()
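
For background on the `consolidate_state_dict()` call (unchanged by this commit apart from the quote style on the signature): each rank holds only its shard of the optimizer state, so the shards must be gathered before a full `state_dict()` exists. A rough sketch, assuming an initialised process group and a `fairscale.optim.OSS` optimizer; the helper name is illustrative.

```python
# Hypothetical helper, assuming torch.distributed is initialised and `optimizer`
# is a fairscale.optim.OSS instance; mirrors what optimizer_state() relies on.
from typing import Optional

import torch.distributed as dist
from fairscale.optim import OSS


def consolidated_optimizer_state(optimizer: OSS) -> Optional[dict]:
    # Every rank participates in the gather; rank 0 ends up with the full state.
    optimizer.consolidate_state_dict()
    if dist.get_rank() == 0:
        return optimizer.state_dict()
    return None
```
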

pytorch_lightning/utilities/__init__.py

@@ -36,6 +36,7 @@ from pytorch_lightning.utilities.imports import ( # noqa: F401
     _DEEPSPEED_AVAILABLE,
     _FAIRSCALE_AVAILABLE,
     _FAIRSCALE_PIPE_AVAILABLE,
+    _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE,
     _GROUP_AVAILABLE,
     _HOROVOD_AVAILABLE,
     _HYDRA_AVAILABLE,

pytorch_lightning/utilities/imports.py

@@ -75,6 +75,7 @@ _BOLTS_AVAILABLE = _module_available('pl_bolts')
 _DEEPSPEED_AVAILABLE = not _IS_WINDOWS and _module_available('deepspeed')
 _FAIRSCALE_AVAILABLE = _TORCH_GREATER_EQUAL_1_6 and not _IS_WINDOWS and _module_available('fairscale.nn')
 _FAIRSCALE_PIPE_AVAILABLE = _FAIRSCALE_AVAILABLE and _compare_version("fairscale", operator.le, "0.1.3")
+_FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE = _FAIRSCALE_AVAILABLE and _compare_version("fairscale", operator.ge, "0.3.3")
 _GROUP_AVAILABLE = not _IS_WINDOWS and _module_available('torch.distributed.group')
 _HOROVOD_AVAILABLE = _module_available("horovod.torch")
 _HYDRA_AVAILABLE = _module_available("hydra")
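
For reference, a rough standalone equivalent of the new availability flag; `_module_available` and `_compare_version` are Lightning internals, so the helpers below are assumptions for illustration only.

```python
# Rough standalone equivalent of the gate above; helper names are illustrative.
import importlib
import importlib.util
import operator

from packaging.version import Version


def module_available(name: str) -> bool:
    # True if the (sub)module can be found without fully importing it.
    try:
        return importlib.util.find_spec(name) is not None
    except ModuleNotFoundError:
        return False


def compare_version(package: str, op, version: str) -> bool:
    if not module_available(package):
        return False
    pkg = importlib.import_module(package)
    return op(Version(pkg.__version__), Version(version))


FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE = (
    module_available("fairscale.nn") and compare_version("fairscale", operator.ge, "0.3.3")
)
```

The `operator.ge, "0.3.3"` bound matches the commit: the fp16-broadcast path is only taken when fairscale is importable and at least that version.
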