Restrict setup methods to accept a single model (#10064)

Authored by Adrian Wälchli on 2021-10-25 18:32:57 +02:00, committed by GitHub
parent cfb2d87765
commit d3e5a43546
5 changed files with 29 additions and 53 deletions

@@ -214,10 +214,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - LightningLite:
     * Added `PrecisionPlugin.forward_context`, making it the default implementation for all `{train,val,test,predict}_step_context()` methods ([#9988](https://github.com/PyTorchLightning/pytorch-lightning/pull/9988))
     * Added `DDPSpawnPlugin.spawn()` for spawning new processes of a given function ([#10018](https://github.com/PyTorchLightning/pytorch-lightning/pull/10018), [#10022](https://github.com/PyTorchLightning/pytorch-lightning/pull/10022))
-    * Added `TrainingTypePlugin.{_setup_model, _setup_optimizer}` methods ([#9994](https://github.com/PyTorchLightning/pytorch-lightning/pull/9994))
+    * Added `TrainingTypePlugin.{_setup_model, _setup_optimizer}` methods ([#9994](https://github.com/PyTorchLightning/pytorch-lightning/pull/9994), [#10064](https://github.com/PyTorchLightning/pytorch-lightning/pull/10064))
     * Implemented `DataParallelPlugin._setup_model` ([#10010](https://github.com/PyTorchLightning/pytorch-lightning/pull/10010))
-    * Implemented `DeepSpeedPlugin._setup_models_and_optimizers` ([#10009](https://github.com/PyTorchLightning/pytorch-lightning/pull/10009))
-    * Implemented `{DDPShardedPlugin,DDPShardedSpawnPlugin}._setup_models_and_optimizers` ([#10028](https://github.com/PyTorchLightning/pytorch-lightning/pull/10028))
+    * Implemented `DeepSpeedPlugin._setup_model_and_optimizers` ([#10009](https://github.com/PyTorchLightning/pytorch-lightning/pull/10009), [#10064](https://github.com/PyTorchLightning/pytorch-lightning/pull/10064))
+    * Implemented `{DDPShardedPlugin,DDPShardedSpawnPlugin}._setup_model_and_optimizers` ([#10028](https://github.com/PyTorchLightning/pytorch-lightning/pull/10028), [#10064](https://github.com/PyTorchLightning/pytorch-lightning/pull/10064))
     * Added optional `model` argument to the `optimizer_step` methods in accelerators and plugins ([#10023](https://github.com/PyTorchLightning/pytorch-lightning/pull/10023))
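
For quick reference, the rename narrows the setup hook from lists of models to a single model. A minimal before/after sketch of a call site, assuming an already constructed `TrainingTypePlugin` instance bound to the hypothetical name `plugin` and an existing `model`/`optimizer` pair (none of these names appear in the commit itself):

```python
# Illustrative only: `plugin`, `model`, and `optimizer` are assumed to exist already.

# Before #10064: list-in / list-out, even though only one model was ever supported.
[model], optimizers = plugin._setup_models_and_optimizers(models=[model], optimizers=[optimizer])

# After #10064: the single model is passed and returned directly.
model, optimizers = plugin._setup_model_and_optimizers(model=model, optimizers=[optimizer])
```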

@@ -379,30 +379,28 @@ class DeepSpeedPlugin(DDPPlugin):
         self.init_deepspeed()
         self.barrier()
-    def _setup_models_and_optimizers(
-        self, models: List[Module], optimizers: List[Optimizer]
-    ) -> Tuple[List[Module], List[Optimizer]]:
-        """Setup multiple models and multiple optimizers together.
+    def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer]) -> Tuple[Module, List[Optimizer]]:
+        """Setup a model and multiple optimizers together.
-        Currently only one model paired with a single optimizer is supported.
+        Currently only a single optimizer is supported.
         Return:
-            A list with one model wrapped into a :class:`deepspeed.DeepSpeedEngine` and list with a single
+            The model wrapped into a :class:`deepspeed.DeepSpeedEngine` and a list with a single
             deepspeed optimizer.
         """
-        if not (len(models) == len(optimizers) == 1):
+        if len(optimizers) != 1:
             raise ValueError(
-                f"Currently only one model and one optimizer is supported with DeepSpeed."
-                f" Got {len(models)} models and {len(optimizers)} optimizers instead."
+                f"Currently only one optimizer is supported with DeepSpeed."
+                f" Got {len(optimizers)} optimizers instead."
             )
         # train_micro_batch_size_per_gpu is used for throughput logging purposes
         # normally we set this to the batch size, but it is not available here unless the user provides it
         # as part of the config
         self.config.setdefault("train_micro_batch_size_per_gpu", 1)
-        self._model, optimizer = self._setup_model_and_optimizer(models[0], optimizers[0])
+        self._model, optimizer = self._setup_model_and_optimizer(model, optimizers[0])
         self._set_deepspeed_activation_checkpointing()
-        return [self._model], [optimizer]
+        return self._model, [optimizer]
     def _setup_model_and_optimizer(
         self, model: Module, optimizer: Optimizer, lr_scheduler: Optional[_LRScheduler] = None
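
A rough sketch of the new DeepSpeed contract, assuming an already configured `DeepSpeedPlugin` instance bound to the hypothetical name `plugin` (the toy model and optimizer are illustrative; actually running this requires a working DeepSpeed setup):

```python
import torch
from torch import nn

model = nn.Linear(8, 8)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Exactly one optimizer is accepted; the model comes back wrapped in a
# deepspeed.DeepSpeedEngine together with a single-element optimizer list.
wrapped_model, optimizers = plugin._setup_model_and_optimizers(model, [optimizer])

# Passing anything other than one optimizer now raises, e.g.:
#   ValueError: Currently only one optimizer is supported with DeepSpeed. Got 2 optimizers instead.
```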

@@ -47,33 +47,23 @@ class DDPShardedPlugin(DDPPlugin):
             # For multi-node training, enabling bucketing will improve performance.
             self._ddp_kwargs["reduce_buffer_size"] = self._REDUCE_BUFFER_SIZE_DEFAULT if self.num_nodes > 1 else 0
-        [self._model], optimizers = self._setup_models_and_optimizers(
-            models=[LightningShardedDataParallel(self.model)],
+        self._model, optimizers = self._setup_model_and_optimizers(
+            model=LightningShardedDataParallel(self.model),
             optimizers=trainer.optimizers,
         )
         trainer.optimizers = optimizers
         trainer.convert_to_lightning_optimizers()
-    def _setup_models_and_optimizers(
-        self, models: List[Module], optimizers: List[Optimizer]
-    ) -> Tuple[List[Module], List[Optimizer]]:
+    def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer]) -> Tuple[Module, List[Optimizer]]:
         """Wraps the model and optimizers with fairscale components.
-        Currently only one model can be setup at once.
         Return:
-            A list with one model wrapped into a :class:`~fairscale.nn.data_parallel.ShardedDataParallel` module
+            The model wrapped into a :class:`~fairscale.nn.data_parallel.ShardedDataParallel` module
             and a list of optimizer wrapped in :class:~`fairscale.optim.OSS`.
         """
-        if len(models) > 1:
-            raise ValueError(
-                "DDPSharded only supports setting up a single model with one or several optimizers."
-                f" Got {len(models)} models."
-            )
         optimizers = self._wrap_optimizers(optimizers)
-        model = ShardedDataParallel(models[0], sharded_optimizer=optimizers, **self._ddp_kwargs)
-        return [model], optimizers
+        model = ShardedDataParallel(model, sharded_optimizer=optimizers, **self._ddp_kwargs)
+        return model, optimizers
     def _reinit_optimizers_with_oss(self, optimizers: List[Union[Optimizer, LightningOptimizer]]) -> List["OSS"]:
         for x, optimizer in enumerate(optimizers):

@@ -39,32 +39,22 @@ class DDPSpawnShardedPlugin(DDPSpawnPlugin):
     def configure_ddp(self) -> None:
         trainer = self.lightning_module.trainer
-        [self._model], optimizers = self._setup_models_and_optimizers(
-            models=[LightningShardedDataParallel(self.model)],
+        self._model, optimizers = self._setup_model_and_optimizers(
+            model=LightningShardedDataParallel(self.model),
             optimizers=trainer.optimizers,
         )
         trainer.optimizers = optimizers
-    def _setup_models_and_optimizers(
-        self, models: List[Module], optimizers: List[Optimizer]
-    ) -> Tuple[List[Module], List[Optimizer]]:
+    def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer]) -> Tuple[Module, List[Optimizer]]:
         """Wraps the model and optimizers with fairscale components.
-        Currently only one model can be setup at once.
         Return:
-            A list with one model wrapped into a :class:`~fairscale.nn.data_parallel.ShardedDataParallel` module
+            The model wrapped into a :class:`~fairscale.nn.data_parallel.ShardedDataParallel` module
             and a list of optimizer wrapped in :class:~`fairscale.optim.OSS`.
         """
-        if len(models) > 1:
-            raise ValueError(
-                f"DDPShardedSpawn only supports setting up a single model with one or several optimizers."
-                f" Got {len(models)} models."
-            )
         optimizers = self._wrap_optimizers(optimizers)
-        model = ShardedDataParallel(models[0], sharded_optimizer=optimizers, **self._ddp_kwargs)
-        return [model], optimizers
+        model = ShardedDataParallel(model, sharded_optimizer=optimizers, **self._ddp_kwargs)
+        return model, optimizers
     def _reinit_optimizers_with_oss(self, optimizers: List[Optimizer]) -> List["OSS"]:
         for x, optimizer in enumerate(optimizers):

@@ -61,18 +61,16 @@ class TrainingTypePlugin(ABC):
     def setup(self) -> None:
         """Called by the accelerator to finish setup."""
-    def _setup_models_and_optimizers(
-        self, models: List[Module], optimizers: List[Optimizer]
-    ) -> Tuple[List[Module], List[Optimizer]]:
-        """Setup multiple models and multiple optimizers together.
+    def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer]) -> Tuple[Module, List[Optimizer]]:
+        """Setup a model and multiple optimizers together.
         The returned objects are expected to be in the same order they were passed in. The default implementation will
-        call :meth:`_setup_model` and :meth:`_setup_optimizer` on the input lists.
+        call :meth:`_setup_model` and :meth:`_setup_optimizer` on the inputs.
         """
         # TODO (@awaelchli): standardize this across all plugins in Lightning and Lite. Related refactor: #7324
-        models = [self._setup_model(model) for model in models]
+        model = self._setup_model(model)
         optimizers = [self._setup_optimizer(optimizer) for optimizer in optimizers]
-        return models, optimizers
+        return model, optimizers
     def _setup_model(self, model: Module) -> Module:
         """Performs setup for the model, e.g., by wrapping it by another class."""