diff --git a/CHANGELOG.md b/CHANGELOG.md
index 398d038595..16e7540381 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -213,6 +213,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
   * Added `TrainingTypePlugin.{_setup_model, _setup_optimizer}` methods ([#9994](https://github.com/PyTorchLightning/pytorch-lightning/pull/9994))
   * Implemented `DataParallelPlugin._setup_model` ([#10010](https://github.com/PyTorchLightning/pytorch-lightning/pull/10010))
   * Implemented `DeepSpeedPlugin._setup_models_and_optimizers` ([#10009](https://github.com/PyTorchLightning/pytorch-lightning/pull/10009))
+  * Implemented `{DDPShardedPlugin,DDPShardedSpawnPlugin}._setup_models_and_optimizers` ([#10028](https://github.com/PyTorchLightning/pytorch-lightning/pull/10028))
 
 * Added optional `model` argument to the `optimizer_step` methods in accelerators and plugins ([#10023](https://github.com/PyTorchLightning/pytorch-lightning/pull/10023))
 
diff --git a/pytorch_lightning/plugins/training_type/sharded.py b/pytorch_lightning/plugins/training_type/sharded.py
index d684a34784..63ac7f5105 100644
--- a/pytorch_lightning/plugins/training_type/sharded.py
+++ b/pytorch_lightning/plugins/training_type/sharded.py
@@ -12,9 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from contextlib import contextmanager
-from typing import Dict, Generator, Optional
+from typing import Dict, Generator, List, Optional, Tuple, Union
 
 import torch
+from torch.nn import Module
+from torch.optim import Optimizer
 
 import pytorch_lightning as pl
 from pytorch_lightning.core.optimizer import LightningOptimizer
@@ -33,24 +35,48 @@ if _FAIRSCALE_AVAILABLE:
 class DDPShardedPlugin(DDPPlugin):
     """Optimizer and gradient sharded training provided by FairScale."""
 
-    _REDUCE_BUFFER_SIZE_DEFAULT = 2 ** 23  # 8M
+    _REDUCE_BUFFER_SIZE_DEFAULT: int = 2 ** 23  # 8M
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._precision = None
 
     def configure_ddp(self) -> None:
-        self._wrap_optimizers()
-
+        trainer = self.lightning_module.trainer
         if "reduce_buffer_size" not in self._ddp_kwargs:
             # For multi-node training, enabling bucketing will improve performance.
             self._ddp_kwargs["reduce_buffer_size"] = self._REDUCE_BUFFER_SIZE_DEFAULT if self.num_nodes > 1 else 0
 
-        self._model = ShardedDataParallel(
-            LightningShardedDataParallel(self.model),
-            sharded_optimizer=self.lightning_module.trainer.optimizers,
-            **self._ddp_kwargs
+        [self._model], optimizers = self._setup_models_and_optimizers(
+            models=[LightningShardedDataParallel(self.model)],
+            optimizers=trainer.optimizers,
         )
-        setattr(self._model, "require_backward_grad_sync", False)
+        trainer.optimizers = optimizers
+        trainer.convert_to_lightning_optimizers()
 
-    def _reinit_optimizers_with_oss(self):
-        optimizers = self.lightning_module.trainer.optimizers
+    def _setup_models_and_optimizers(
+        self, models: List[Module], optimizers: List[Optimizer]
+    ) -> Tuple[List[Module], List[Optimizer]]:
+        """Wraps the model and optimizers with fairscale components.
+
+        Currently only one model can be set up at once.
+
+        Return:
+            A list with one model wrapped into a :class:`~fairscale.nn.data_parallel.ShardedDataParallel` module
+            and a list of optimizers wrapped in :class:`~fairscale.optim.OSS`.
+        """
+        if len(models) > 1:
+            raise ValueError(
+                "DDPSharded only supports setting up a single model with one or several optimizers."
+                f" Got {len(models)} models."
+            )
+
+        optimizers = self._wrap_optimizers(optimizers)
+        model = ShardedDataParallel(models[0], sharded_optimizer=optimizers, **self._ddp_kwargs)
+        setattr(model, "require_backward_grad_sync", False)  # TODO: needed?
+        return [model], optimizers
+
+    def _reinit_optimizers_with_oss(self, optimizers: List[Union[Optimizer, LightningOptimizer]]) -> List["OSS"]:
         for x, optimizer in enumerate(optimizers):
             if isinstance(optimizer, LightningOptimizer):
                 optimizer = optimizer._optimizer
@@ -58,7 +84,7 @@ class DDPShardedPlugin(DDPPlugin):
                 optim_class = type(optimizer)
                 zero_optimizer = OSS(params=optimizer.param_groups, optim=optim_class, **optimizer.defaults)
                 if _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE:
-                    precision = self.lightning_module.trainer.precision
+                    precision = self._precision or self.lightning_module.trainer.precision
                     is_fp16 = precision in ("mixed", 16)
                     # For multi-node training, compressing the model shards in fp16 before broadcasting
                     # improves performance. When using PyTorch AMP, it will not degrade
@@ -66,14 +92,13 @@ class DDPShardedPlugin(DDPPlugin):
                     zero_optimizer.broadcast_fp16 = is_fp16 and self.num_nodes > 1
                 optimizers[x] = zero_optimizer
                 del optimizer
-        trainer = self.lightning_module.trainer
-        trainer.optimizers = optimizers
-        trainer.convert_to_lightning_optimizers()
+        return optimizers
 
-    def _wrap_optimizers(self):
-        if self.model.trainer.state.fn != TrainerFn.FITTING:
-            return
-        self._reinit_optimizers_with_oss()
+    def _wrap_optimizers(self, optimizers: List[Optimizer]) -> List["OSS"]:
+        if self.model is not None and self.model.trainer.state.fn != TrainerFn.FITTING:
+            return optimizers
+
+        return self._reinit_optimizers_with_oss(optimizers)
 
     def optimizer_state(self, optimizer: "OSS") -> Optional[dict]:
         if isinstance(optimizer, LightningOptimizer):
diff --git a/pytorch_lightning/plugins/training_type/sharded_spawn.py b/pytorch_lightning/plugins/training_type/sharded_spawn.py
index 78b54d029a..5d48c489a3 100644
--- a/pytorch_lightning/plugins/training_type/sharded_spawn.py
+++ b/pytorch_lightning/plugins/training_type/sharded_spawn.py
@@ -13,9 +13,11 @@
 # limitations under the License.
 from contextlib import contextmanager
 from multiprocessing.queues import SimpleQueue
-from typing import Dict, Generator, Optional
+from typing import Dict, Generator, List, Optional, Tuple
 
 import torch
+from torch.nn import Module
+from torch.optim import Optimizer
 
 import pytorch_lightning as pl
 from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin
@@ -36,29 +38,49 @@ class DDPSpawnShardedPlugin(DDPSpawnPlugin):
     """Optimizer sharded training provided by FairScale."""
 
     def configure_ddp(self) -> None:
-        self._wrap_optimizers()
-        self._model = ShardedDataParallel(
-            LightningShardedDataParallel(self.model),
-            sharded_optimizer=self.lightning_module.trainer.optimizers,
-            **self._ddp_kwargs
+        trainer = self.lightning_module.trainer
+        [self._model], optimizers = self._setup_models_and_optimizers(
+            models=[LightningShardedDataParallel(self.model)],
+            optimizers=trainer.optimizers,
         )
-        setattr(self._model, "require_backward_grad_sync", False)
+        trainer.optimizers = optimizers
 
-    def _reinit_optimizers_with_oss(self):
-        optimizers = self.lightning_module.trainer.optimizers
+    def _setup_models_and_optimizers(
+        self, models: List[Module], optimizers: List[Optimizer]
+    ) -> Tuple[List[Module], List[Optimizer]]:
+        """Wraps the model and optimizers with fairscale components.
+
+        Currently only one model can be set up at once.
+
+        Return:
+            A list with one model wrapped into a :class:`~fairscale.nn.data_parallel.ShardedDataParallel` module
+            and a list of optimizers wrapped in :class:`~fairscale.optim.OSS`.
+        """
+        if len(models) > 1:
+            raise ValueError(
+                "DDPShardedSpawn only supports setting up a single model with one or several optimizers."
+                f" Got {len(models)} models."
+            )
+
+        optimizers = self._wrap_optimizers(optimizers)
+        model = ShardedDataParallel(models[0], sharded_optimizer=optimizers, **self._ddp_kwargs)
+        setattr(model, "require_backward_grad_sync", False)  # TODO: needed?
+        return [model], optimizers
+
+    def _reinit_optimizers_with_oss(self, optimizers: List[Optimizer]) -> List["OSS"]:
         for x, optimizer in enumerate(optimizers):
             if not isinstance(optimizer, OSS):
                 optim_class = type(optimizer)
                 zero_optimizer = OSS(params=optimizer.param_groups, optim=optim_class, **optimizer.defaults)
                 optimizers[x] = zero_optimizer
                 del optimizer
-        trainer = self.lightning_module.trainer
-        trainer.optimizers = optimizers
+        return optimizers
 
-    def _wrap_optimizers(self):
-        if self.model.trainer.state.fn != TrainerFn.FITTING:
-            return
-        self._reinit_optimizers_with_oss()
+    def _wrap_optimizers(self, optimizers: List[Optimizer]) -> List["OSS"]:
+        if self.model is not None and self.model.trainer.state.fn != TrainerFn.FITTING:
+            return optimizers
+
+        return self._reinit_optimizers_with_oss(optimizers)
 
     def optimizer_state(self, optimizer: "OSS") -> Optional[dict]:
         if isinstance(optimizer, OSS):
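For context, the sketch below (not part of the patch) shows what the new `_setup_models_and_optimizers` hook boils down to when expressed with fairscale directly: optimizers are re-created as zero-redundancy `OSS` optimizers and the single model is wrapped in `ShardedDataParallel`. It assumes `torch.distributed` is already initialized (e.g. via `torchrun`); `setup_sharded` and `Net` are hypothetical names used only for illustration.

```python
# Minimal sketch, assuming an initialized torch.distributed process group.
# `setup_sharded` and `Net` are hypothetical; the fairscale calls mirror the ones
# used by the plugins above.
import torch
from torch import nn
from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim import OSS


def setup_sharded(module: nn.Module, base_optimizer: torch.optim.Optimizer):
    # Mirrors `_reinit_optimizers_with_oss`: re-create the optimizer as an OSS
    # (zero-redundancy) optimizer, keeping the original class and its defaults.
    oss_optimizer = OSS(
        params=base_optimizer.param_groups,
        optim=type(base_optimizer),
        **base_optimizer.defaults,
    )
    # Mirrors `_setup_models_and_optimizers`: wrap the single model together with
    # its sharded optimizer(s) so each rank reduces only the gradients it owns.
    sharded_module = ShardedDataParallel(module, sharded_optimizer=[oss_optimizer])
    return sharded_module, oss_optimizer


# Example usage on one rank:
# net = Net()
# model, optimizer = setup_sharded(net, torch.optim.Adam(net.parameters(), lr=1e-3))
```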