Update setup logic in training type plugins (sharded) [4 / 4] (#10028)
Co-authored-by: Kaushik B <45285388+kaushikb11@users.noreply.github.com>
Co-authored-by: thomas chaton <thomas@grid.ai>

parent 84706a256e
commit 4ea72a9365
@@ -213,6 +213,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 * Added `TrainingTypePlugin.{_setup_model, _setup_optimizer}` methods ([#9994](https://github.com/PyTorchLightning/pytorch-lightning/pull/9994))
 * Implemented `DataParallelPlugin._setup_model` ([#10010](https://github.com/PyTorchLightning/pytorch-lightning/pull/10010))
 * Implemented `DeepSpeedPlugin._setup_models_and_optimizers` ([#10009](https://github.com/PyTorchLightning/pytorch-lightning/pull/10009))
+* Implemented `{DDPShardedPlugin,DDPShardedSpawnPlugin}._setup_models_and_optimizers` ([#10028](https://github.com/PyTorchLightning/pytorch-lightning/pull/10028))
 * Added optional `model` argument to the `optimizer_step` methods in accelerators and plugins ([#10023](https://github.com/PyTorchLightning/pytorch-lightning/pull/10023))
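For context on the entries above: `_setup_models_and_optimizers` (added to the base plugin in #9994) is the hook that the sharded plugins override in this commit. A plausible sketch of the plugin-side contract, assuming the base implementation simply maps the per-object hooks over its inputs (the class and body below are illustrative, not taken from this diff):

```python
from typing import List, Tuple

from torch.nn import Module
from torch.optim import Optimizer


class TrainingTypePluginSketch:
    """Illustrative stand-in for TrainingTypePlugin; only the new setup hooks are sketched."""

    def _setup_models_and_optimizers(
        self, models: List[Module], optimizers: List[Optimizer]
    ) -> Tuple[List[Module], List[Optimizer]]:
        # Assumed default behavior: delegate to the per-model / per-optimizer hooks.
        models = [self._setup_model(m) for m in models]
        optimizers = [self._setup_optimizer(o) for o in optimizers]
        return models, optimizers

    def _setup_model(self, model: Module) -> Module:
        return model  # subclasses wrap the model (e.g. in ShardedDataParallel)

    def _setup_optimizer(self, optimizer: Optimizer) -> Optimizer:
        return optimizer  # subclasses wrap the optimizer (e.g. in OSS)
```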
@@ -12,9 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from contextlib import contextmanager
-from typing import Dict, Generator, Optional
+from typing import Dict, Generator, List, Optional, Tuple, Union
 
 import torch
+from torch.nn import Module
+from torch.optim import Optimizer
 
 import pytorch_lightning as pl
 from pytorch_lightning.core.optimizer import LightningOptimizer
@@ -33,24 +35,48 @@ if _FAIRSCALE_AVAILABLE:
 class DDPShardedPlugin(DDPPlugin):
     """Optimizer and gradient sharded training provided by FairScale."""
 
-    _REDUCE_BUFFER_SIZE_DEFAULT = 2 ** 23  # 8M
+    _REDUCE_BUFFER_SIZE_DEFAULT: int = 2 ** 23  # 8M
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._precision = None
 
     def configure_ddp(self) -> None:
-        self._wrap_optimizers()
+        trainer = self.lightning_module.trainer
         if "reduce_buffer_size" not in self._ddp_kwargs:
             # For multi-node training, enabling bucketing will improve performance.
             self._ddp_kwargs["reduce_buffer_size"] = self._REDUCE_BUFFER_SIZE_DEFAULT if self.num_nodes > 1 else 0
 
-        self._model = ShardedDataParallel(
-            LightningShardedDataParallel(self.model),
-            sharded_optimizer=self.lightning_module.trainer.optimizers,
-            **self._ddp_kwargs
+        [self._model], optimizers = self._setup_models_and_optimizers(
+            models=[LightningShardedDataParallel(self.model)],
+            optimizers=trainer.optimizers,
         )
-        setattr(self._model, "require_backward_grad_sync", False)
+        trainer.optimizers = optimizers
+        trainer.convert_to_lightning_optimizers()
 
-    def _reinit_optimizers_with_oss(self):
-        optimizers = self.lightning_module.trainer.optimizers
+    def _setup_models_and_optimizers(
+        self, models: List[Module], optimizers: List[Optimizer]
+    ) -> Tuple[List[Module], List[Optimizer]]:
+        """Wraps the model and optimizers with fairscale components.
+
+        Currently only one model can be set up at once.
+
+        Return:
+            A list with one model wrapped into a :class:`~fairscale.nn.data_parallel.ShardedDataParallel` module
+            and a list of optimizers wrapped in :class:`~fairscale.optim.OSS`.
+        """
+        if len(models) > 1:
+            raise ValueError(
+                "DDPSharded only supports setting up a single model with one or several optimizers."
+                f" Got {len(models)} models."
+            )
+
+        optimizers = self._wrap_optimizers(optimizers)
+        model = ShardedDataParallel(models[0], sharded_optimizer=optimizers, **self._ddp_kwargs)
+        setattr(model, "require_backward_grad_sync", False)  # TODO: needed?
+        return [model], optimizers
+
+    def _reinit_optimizers_with_oss(self, optimizers: List[Union[Optimizer, LightningOptimizer]]) -> List["OSS"]:
         for x, optimizer in enumerate(optimizers):
             if isinstance(optimizer, LightningOptimizer):
                 optimizer = optimizer._optimizer
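Outside of the plugin, the wrapping that `_setup_models_and_optimizers` performs above reduces to two fairscale calls: rebuild each optimizer as an `OSS` optimizer and wrap the module in `ShardedDataParallel`. A minimal, self-contained sketch of that pattern; everything besides the fairscale calls mirrored from the diff (the toy module, single-process "gloo" group, and hyperparameters) is an illustrative assumption:

```python
# Standalone sketch, not part of the commit. Assumes fairscale is installed; a
# single-process "gloo" process group is created only so the example can run.
import os

import torch
import torch.distributed as dist
from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim import OSS

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

model = torch.nn.Linear(32, 2)  # stands in for LightningShardedDataParallel(self.model)
base_optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Re-create the optimizer as a sharded OSS optimizer, as _reinit_optimizers_with_oss does.
oss_optimizer = OSS(params=base_optimizer.param_groups, optim=type(base_optimizer), **base_optimizer.defaults)

# Wrap the module so gradients are reduced to the rank that owns each optimizer shard.
sharded_model = ShardedDataParallel(model, sharded_optimizer=[oss_optimizer])

loss = sharded_model(torch.randn(4, 32)).sum()
loss.backward()
oss_optimizer.step()

dist.destroy_process_group()
```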
@@ -58,7 +84,7 @@ class DDPShardedPlugin(DDPPlugin):
                 optim_class = type(optimizer)
                 zero_optimizer = OSS(params=optimizer.param_groups, optim=optim_class, **optimizer.defaults)
                 if _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE:
-                    precision = self.lightning_module.trainer.precision
+                    precision = self._precision or self.lightning_module.trainer.precision
                     is_fp16 = precision in ("mixed", 16)
                     # For multi-node training, compressing the model shards in fp16 before broadcasting
                     # improves performance. When using PyTorch AMP, it will not degrade
@@ -66,14 +92,13 @@ class DDPShardedPlugin(DDPPlugin):
                     zero_optimizer.broadcast_fp16 = is_fp16 and self.num_nodes > 1
                 optimizers[x] = zero_optimizer
                 del optimizer
-        trainer = self.lightning_module.trainer
-        trainer.optimizers = optimizers
-        trainer.convert_to_lightning_optimizers()
+        return optimizers
 
-    def _wrap_optimizers(self):
-        if self.model.trainer.state.fn != TrainerFn.FITTING:
-            return
-        self._reinit_optimizers_with_oss()
+    def _wrap_optimizers(self, optimizers: List[Optimizer]) -> List["OSS"]:
+        if self.model is not None and self.model.trainer.state.fn != TrainerFn.FITTING:
+            return optimizers
+
+        return self._reinit_optimizers_with_oss(optimizers)
 
     def optimizer_state(self, optimizer: "OSS") -> Optional[dict]:
         if isinstance(optimizer, LightningOptimizer):
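The refactored helpers above now take the optimizer list as input and return it, instead of mutating `trainer.optimizers` in place, and the FP16 shard-broadcast decision falls back from the plugin's own `_precision` to the trainer's precision. That decision is a small pure predicate; a sketch of it in isolation (the helper name is hypothetical, the logic mirrors the diff):

```python
from typing import Optional, Union


def should_broadcast_fp16(
    plugin_precision: Optional[Union[str, int]],
    trainer_precision: Union[str, int],
    num_nodes: int,
) -> bool:
    """Mirror of the condition guarding `zero_optimizer.broadcast_fp16` in the diff above."""
    precision = plugin_precision or trainer_precision
    is_fp16 = precision in ("mixed", 16)
    # Compressing shard broadcasts to fp16 only pays off across nodes.
    return is_fp16 and num_nodes > 1


assert should_broadcast_fp16(None, 16, num_nodes=2) is True
assert should_broadcast_fp16(None, 32, num_nodes=2) is False
assert should_broadcast_fp16("mixed", 32, num_nodes=1) is False
```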
@@ -13,9 +13,11 @@
 # limitations under the License.
 from contextlib import contextmanager
 from multiprocessing.queues import SimpleQueue
-from typing import Dict, Generator, Optional
+from typing import Dict, Generator, List, Optional, Tuple
 
 import torch
+from torch.nn import Module
+from torch.optim import Optimizer
 
 import pytorch_lightning as pl
 from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin
@@ -36,29 +38,49 @@ class DDPSpawnShardedPlugin(DDPSpawnPlugin):
     """Optimizer sharded training provided by FairScale."""
 
     def configure_ddp(self) -> None:
-        self._wrap_optimizers()
-        self._model = ShardedDataParallel(
-            LightningShardedDataParallel(self.model),
-            sharded_optimizer=self.lightning_module.trainer.optimizers,
-            **self._ddp_kwargs
+        trainer = self.lightning_module.trainer
+        [self._model], optimizers = self._setup_models_and_optimizers(
+            models=[LightningShardedDataParallel(self.model)],
+            optimizers=trainer.optimizers,
         )
-        setattr(self._model, "require_backward_grad_sync", False)
+        trainer.optimizers = optimizers
 
-    def _reinit_optimizers_with_oss(self):
-        optimizers = self.lightning_module.trainer.optimizers
+    def _setup_models_and_optimizers(
+        self, models: List[Module], optimizers: List[Optimizer]
+    ) -> Tuple[List[Module], List[Optimizer]]:
+        """Wraps the model and optimizers with fairscale components.
+
+        Currently only one model can be set up at once.
+
+        Return:
+            A list with one model wrapped into a :class:`~fairscale.nn.data_parallel.ShardedDataParallel` module
+            and a list of optimizers wrapped in :class:`~fairscale.optim.OSS`.
+        """
+        if len(models) > 1:
+            raise ValueError(
+                "DDPShardedSpawn only supports setting up a single model with one or several optimizers."
+                f" Got {len(models)} models."
+            )
+
+        optimizers = self._wrap_optimizers(optimizers)
+        model = ShardedDataParallel(models[0], sharded_optimizer=optimizers, **self._ddp_kwargs)
+        setattr(model, "require_backward_grad_sync", False)  # TODO: needed?
+        return [model], optimizers
+
+    def _reinit_optimizers_with_oss(self, optimizers: List[Optimizer]) -> List["OSS"]:
         for x, optimizer in enumerate(optimizers):
             if not isinstance(optimizer, OSS):
                 optim_class = type(optimizer)
                 zero_optimizer = OSS(params=optimizer.param_groups, optim=optim_class, **optimizer.defaults)
                 optimizers[x] = zero_optimizer
                 del optimizer
-        trainer = self.lightning_module.trainer
-        trainer.optimizers = optimizers
+        return optimizers
 
-    def _wrap_optimizers(self):
-        if self.model.trainer.state.fn != TrainerFn.FITTING:
-            return
-        self._reinit_optimizers_with_oss()
+    def _wrap_optimizers(self, optimizers: List[Optimizer]) -> List["OSS"]:
+        if self.model is not None and self.model.trainer.state.fn != TrainerFn.FITTING:
+            return optimizers
+
+        return self._reinit_optimizers_with_oss(optimizers)
 
     def optimizer_state(self, optimizer: "OSS") -> Optional[dict]:
         if isinstance(optimizer, OSS):
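For completeness, a user-side sketch showing how these two plugins are selected on a Trainer. The exact flag depends on the PyTorch Lightning version around this commit (the string aliases were accepted via `plugins` earlier and via `strategy` later), so treat the argument names below as assumptions:

```python
# Hypothetical user configuration; requires a multi-GPU machine with fairscale installed.
import pytorch_lightning as pl

trainer = pl.Trainer(
    accelerator="gpu",
    devices=2,
    strategy="ddp_sharded",  # or "ddp_sharded_spawn" for the spawn-based variant
    precision=16,  # mixed precision enables the fp16 shard-broadcast path on multi-node runs
)
# trainer.fit(model, datamodule=datamodule)  # model/datamodule are the user's own objects
```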