Update setup logic in training type plugins (sharded) [4 / 4] (#10028)

Co-authored-by: Kaushik B <45285388+kaushikb11@users.noreply.github.com>
Co-authored-by: thomas chaton <thomas@grid.ai>
Adrian Wälchli 2021-10-21 10:35:01 +02:00 committed by GitHub
parent 84706a256e
commit 4ea72a9365
3 changed files with 82 additions and 34 deletions

CHANGELOG.md

@@ -213,6 +213,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
     * Added `TrainingTypePlugin.{_setup_model, _setup_optimizer}` methods ([#9994](https://github.com/PyTorchLightning/pytorch-lightning/pull/9994))
     * Implemented `DataParallelPlugin._setup_model` ([#10010](https://github.com/PyTorchLightning/pytorch-lightning/pull/10010))
     * Implemented `DeepSpeedPlugin._setup_models_and_optimizers` ([#10009](https://github.com/PyTorchLightning/pytorch-lightning/pull/10009))
+    * Implemented `{DDPShardedPlugin,DDPShardedSpawnPlugin}._setup_models_and_optimizers` ([#10028](https://github.com/PyTorchLightning/pytorch-lightning/pull/10028))
     * Added optional `model` argument to the `optimizer_step` methods in accelerators and plugins ([#10023](https://github.com/PyTorchLightning/pytorch-lightning/pull/10023))
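For orientation, a minimal caller-side sketch of the contract these `_setup_models_and_optimizers` entries add (the `setup_for_training` helper and the `plugin` argument are illustrative, not part of this commit): the caller passes in bare modules and optimizers and gets the wrapped versions back.

```python
from typing import Tuple

from torch.nn import Module
from torch.optim import Optimizer


def setup_for_training(plugin, model: Module, optimizer: Optimizer) -> Tuple[Module, Optimizer]:
    # Hypothetical helper: `plugin` is assumed to be an already-configured sharded
    # training type plugin. It wraps the module, re-creates the optimizer, and
    # returns both so the caller can keep references to the wrapped objects.
    [wrapped_model], [wrapped_optimizer] = plugin._setup_models_and_optimizers(
        models=[model], optimizers=[optimizer]
    )
    return wrapped_model, wrapped_optimizer
```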

pytorch_lightning/plugins/training_type/sharded.py

@@ -12,9 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from contextlib import contextmanager
-from typing import Dict, Generator, Optional
+from typing import Dict, Generator, List, Optional, Tuple, Union
 
 import torch
+from torch.nn import Module
+from torch.optim import Optimizer
 
 import pytorch_lightning as pl
 from pytorch_lightning.core.optimizer import LightningOptimizer
@@ -33,24 +35,48 @@ if _FAIRSCALE_AVAILABLE:
 class DDPShardedPlugin(DDPPlugin):
     """Optimizer and gradient sharded training provided by FairScale."""
 
-    _REDUCE_BUFFER_SIZE_DEFAULT = 2 ** 23  # 8M
+    _REDUCE_BUFFER_SIZE_DEFAULT: int = 2 ** 23  # 8M
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._precision = None
 
     def configure_ddp(self) -> None:
-        self._wrap_optimizers()
+        trainer = self.lightning_module.trainer
         if "reduce_buffer_size" not in self._ddp_kwargs:
             # For multi-node training, enabling bucketing will improve performance.
             self._ddp_kwargs["reduce_buffer_size"] = self._REDUCE_BUFFER_SIZE_DEFAULT if self.num_nodes > 1 else 0
 
-        self._model = ShardedDataParallel(
-            LightningShardedDataParallel(self.model),
-            sharded_optimizer=self.lightning_module.trainer.optimizers,
-            **self._ddp_kwargs
+        [self._model], optimizers = self._setup_models_and_optimizers(
+            models=[LightningShardedDataParallel(self.model)],
+            optimizers=trainer.optimizers,
         )
-        setattr(self._model, "require_backward_grad_sync", False)
+        trainer.optimizers = optimizers
+        trainer.convert_to_lightning_optimizers()
 
-    def _reinit_optimizers_with_oss(self):
-        optimizers = self.lightning_module.trainer.optimizers
+    def _setup_models_and_optimizers(
+        self, models: List[Module], optimizers: List[Optimizer]
+    ) -> Tuple[List[Module], List[Optimizer]]:
+        """Wraps the model and optimizers with fairscale components.
+
+        Currently only one model can be setup at once.
+
+        Return:
+            A list with one model wrapped into a :class:`~fairscale.nn.data_parallel.ShardedDataParallel` module
+            and a list of optimizer wrapped in :class:~`fairscale.optim.OSS`.
+        """
+        if len(models) > 1:
+            raise ValueError(
+                "DDPSharded only supports setting up a single model with one or several optimizers."
+                f" Got {len(models)} models."
+            )
+
+        optimizers = self._wrap_optimizers(optimizers)
+        model = ShardedDataParallel(models[0], sharded_optimizer=optimizers, **self._ddp_kwargs)
+        setattr(model, "require_backward_grad_sync", False)  # TODO: needed?
+        return [model], optimizers
+
+    def _reinit_optimizers_with_oss(self, optimizers: List[Union[Optimizer, LightningOptimizer]]) -> List["OSS"]:
         for x, optimizer in enumerate(optimizers):
             if isinstance(optimizer, LightningOptimizer):
                 optimizer = optimizer._optimizer
@@ -58,7 +84,7 @@ class DDPShardedPlugin(DDPPlugin):
                 optim_class = type(optimizer)
                 zero_optimizer = OSS(params=optimizer.param_groups, optim=optim_class, **optimizer.defaults)
                 if _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE:
-                    precision = self.lightning_module.trainer.precision
+                    precision = self._precision or self.lightning_module.trainer.precision
                     is_fp16 = precision in ("mixed", 16)
                     # For multi-node training, compressing the model shards in fp16 before broadcasting
                     # improves performance. When using PyTorch AMP, it will not degrade
@@ -66,14 +92,13 @@ class DDPShardedPlugin(DDPPlugin):
                     zero_optimizer.broadcast_fp16 = is_fp16 and self.num_nodes > 1
                 optimizers[x] = zero_optimizer
                 del optimizer
-        trainer = self.lightning_module.trainer
-        trainer.optimizers = optimizers
-        trainer.convert_to_lightning_optimizers()
+        return optimizers
 
-    def _wrap_optimizers(self):
-        if self.model.trainer.state.fn != TrainerFn.FITTING:
-            return
-        self._reinit_optimizers_with_oss()
+    def _wrap_optimizers(self, optimizers: List[Optimizer]) -> List["OSS"]:
+        if self.model is not None and self.model.trainer.state.fn != TrainerFn.FITTING:
+            return optimizers
+        return self._reinit_optimizers_with_oss(optimizers)
 
     def optimizer_state(self, optimizer: "OSS") -> Optional[dict]:
         if isinstance(optimizer, LightningOptimizer):
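The new `_setup_models_and_optimizers` above performs the same wrapping that `configure_ddp` previously did inline. A rough standalone sketch of that wrapping using fairscale directly (the single-process gloo group, toy model, and optimizer are assumptions for illustration, not code from this commit):

```python
import torch
import torch.distributed as dist
from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim import OSS

# ShardedDataParallel and OSS require an initialized process group; a
# single-process gloo group is enough for a local illustration.
dist.init_process_group(backend="gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1)

model = torch.nn.Linear(32, 2)
base_optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Mirrors _reinit_optimizers_with_oss: re-create the optimizer as a sharded OSS
# optimizer from the original parameter groups and defaults.
zero_optimizer = OSS(params=base_optimizer.param_groups, optim=type(base_optimizer), **base_optimizer.defaults)

# Mirrors the ShardedDataParallel wrapping in _setup_models_and_optimizers: gradients
# are reduced only to the rank that owns the corresponding optimizer shard.
sharded_model = ShardedDataParallel(model, sharded_optimizer=[zero_optimizer])
```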

pytorch_lightning/plugins/training_type/sharded_spawn.py

@@ -13,9 +13,11 @@
 # limitations under the License.
 from contextlib import contextmanager
 from multiprocessing.queues import SimpleQueue
-from typing import Dict, Generator, Optional
+from typing import Dict, Generator, List, Optional, Tuple
 
 import torch
+from torch.nn import Module
+from torch.optim import Optimizer
 
 import pytorch_lightning as pl
 from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin
@@ -36,29 +38,49 @@ class DDPSpawnShardedPlugin(DDPSpawnPlugin):
     """Optimizer sharded training provided by FairScale."""
 
     def configure_ddp(self) -> None:
-        self._wrap_optimizers()
-        self._model = ShardedDataParallel(
-            LightningShardedDataParallel(self.model),
-            sharded_optimizer=self.lightning_module.trainer.optimizers,
-            **self._ddp_kwargs
+        trainer = self.lightning_module.trainer
+        [self._model], optimizers = self._setup_models_and_optimizers(
+            models=[LightningShardedDataParallel(self.model)],
+            optimizers=trainer.optimizers,
         )
-        setattr(self._model, "require_backward_grad_sync", False)
+        trainer.optimizers = optimizers
 
-    def _reinit_optimizers_with_oss(self):
-        optimizers = self.lightning_module.trainer.optimizers
+    def _setup_models_and_optimizers(
+        self, models: List[Module], optimizers: List[Optimizer]
+    ) -> Tuple[List[Module], List[Optimizer]]:
+        """Wraps the model and optimizers with fairscale components.
+
+        Currently only one model can be setup at once.
+
+        Return:
+            A list with one model wrapped into a :class:`~fairscale.nn.data_parallel.ShardedDataParallel` module
+            and a list of optimizer wrapped in :class:~`fairscale.optim.OSS`.
+        """
+        if len(models) > 1:
+            raise ValueError(
+                f"DDPShardedSpawn only supports setting up a single model with one or several optimizers."
+                f" Got {len(models)} models."
+            )
+
+        optimizers = self._wrap_optimizers(optimizers)
+        model = ShardedDataParallel(models[0], sharded_optimizer=optimizers, **self._ddp_kwargs)
+        setattr(model, "require_backward_grad_sync", False)  # TODO: needed?
+        return [model], optimizers
+
+    def _reinit_optimizers_with_oss(self, optimizers: List[Optimizer]) -> List["OSS"]:
         for x, optimizer in enumerate(optimizers):
             if not isinstance(optimizer, OSS):
                 optim_class = type(optimizer)
                 zero_optimizer = OSS(params=optimizer.param_groups, optim=optim_class, **optimizer.defaults)
                 optimizers[x] = zero_optimizer
                 del optimizer
-        trainer = self.lightning_module.trainer
-        trainer.optimizers = optimizers
+        return optimizers
 
-    def _wrap_optimizers(self):
-        if self.model.trainer.state.fn != TrainerFn.FITTING:
-            return
-        self._reinit_optimizers_with_oss()
+    def _wrap_optimizers(self, optimizers: List[Optimizer]) -> List["OSS"]:
+        if self.model is not None and self.model.trainer.state.fn != TrainerFn.FITTING:
+            return optimizers
+        return self._reinit_optimizers_with_oss(optimizers)
 
     def optimizer_state(self, optimizer: "OSS") -> Optional[dict]:
         if isinstance(optimizer, OSS):
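Unlike `DDPShardedPlugin._reinit_optimizers_with_oss`, the spawn variant does not tune `broadcast_fp16`; its re-init reduces to the OSS conversion sketched below (the `to_oss` name is hypothetical, and a torch.distributed process group must already be initialized, as in the previous example):

```python
from typing import List

from fairscale.optim import OSS
from torch.optim import Optimizer


def to_oss(optimizers: List[Optimizer]) -> List[Optimizer]:
    # Re-create each plain optimizer as a sharded OSS optimizer; optimizers that
    # are already OSS instances are left untouched, so the call is idempotent.
    for i, optimizer in enumerate(optimizers):
        if not isinstance(optimizer, OSS):
            optimizers[i] = OSS(params=optimizer.param_groups, optim=type(optimizer), **optimizer.defaults)
    return optimizers
```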