diff --git a/CHANGELOG.md b/CHANGELOG.md index 8668f8ae50..b85f3c76ca 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -202,6 +202,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - LightningLite: * Added `PrecisionPlugin.forward_context`, making it the default implementation for all `{train,val,test,predict}_step_context()` methods ([#9988](https://github.com/PyTorchLightning/pytorch-lightning/pull/9988)) * Added `DDPSpawnPlugin.spawn()` for spawning new processes of a given function ([#10018](https://github.com/PyTorchLightning/pytorch-lightning/pull/10018)) + * Added `TrainingTypePlugin.{_setup_model, _setup_optimizer}` methods ([#9994](https://github.com/PyTorchLightning/pytorch-lightning/pull/9994)) ### Changed diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index a26b63151f..64fc1a5a97 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -27,6 +27,7 @@ import __main__ import numpy as np import torch import torch.distributed +from torch.nn import Module from torch.nn.parallel.distributed import DistributedDataParallel import pytorch_lightning as pl @@ -181,6 +182,10 @@ class DDPPlugin(ParallelPlugin): self.setup_distributed() + def _setup_model(self, model: Module) -> DistributedDataParallel: + """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module.""" + return DistributedDataParallel(module=model, device_ids=self.determine_ddp_device_ids(), **self._ddp_kwargs) + def _call_children_scripts(self): # bookkeeping of spawned processes self._check_can_spawn_children() @@ -355,9 +360,7 @@ class DDPPlugin(ParallelPlugin): def configure_ddp(self) -> None: self.pre_configure_ddp() - self._model = DistributedDataParallel( - LightningDistributedModule(self.model), device_ids=self.determine_ddp_device_ids(), **self._ddp_kwargs - ) + self._model = 
self._setup_model(LightningDistributedModule(self.model)) self._register_ddp_hooks() def determine_ddp_device_ids(self): diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 177a58a691..c72cc7f31d 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -21,6 +21,7 @@ import numpy as np import torch import torch.distributed import torch.multiprocessing as mp +from torch.nn import Module from torch.nn.parallel.distributed import DistributedDataParallel import pytorch_lightning as pl @@ -147,6 +148,10 @@ class DDPSpawnPlugin(ParallelPlugin): smp = mp.get_context("spawn") self.mp_queue = smp.SimpleQueue() + def _setup_model(self, model: Module) -> DistributedDataParallel: + """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module.""" + return DistributedDataParallel(module=model, device_ids=self.determine_ddp_device_ids(), **self._ddp_kwargs) + def set_world_ranks(self, process_idx: int = 0) -> None: self._local_rank = process_idx if self.cluster_environment is None: @@ -263,9 +268,7 @@ class DDPSpawnPlugin(ParallelPlugin): def configure_ddp(self) -> None: self.pre_configure_ddp() - self._model = DistributedDataParallel( - LightningDistributedModule(self.model), device_ids=self.determine_ddp_device_ids(), **self._ddp_kwargs - ) + self._model = self._setup_model(LightningDistributedModule(self.model)) self._register_ddp_hooks() def determine_ddp_device_ids(self): diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index 9c53069063..481b9ee1c4 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -13,11 +13,12 @@ # limitations under the License. 
import contextlib from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, Generator, Iterable, Mapping, Optional, Union +from typing import Any, Callable, Dict, Generator, Iterable, List, Mapping, Optional, Tuple, Union import torch from torch import Tensor from torch.nn import Module +from torch.optim import Optimizer from torch.utils.data import DataLoader import pytorch_lightning as pl @@ -60,6 +61,29 @@ class TrainingTypePlugin(ABC): def setup(self) -> None: """Called by the accelerator to finish setup.""" + def _setup_models_and_optimizers( + self, models: List[Module], optimizers: List[Optimizer] + ) -> Tuple[List[Module], List[Optimizer]]: + """Set up multiple models and multiple optimizers together. + + The returned objects are expected to be in the same order they were passed in. The default implementation will + call :meth:`_setup_model` and :meth:`_setup_optimizer` on each element of the input lists. + """ + # TODO (@awaelchli): standardize this across all plugins in Lightning and Lite. Related refactor: #7324 + models = [self._setup_model(model) for model in models] + optimizers = [self._setup_optimizer(optimizer) for optimizer in optimizers] + return models, optimizers + + def _setup_model(self, model: Module) -> Module: + """Performs setup for the model, e.g., by wrapping it in another class.""" + # TODO (@awaelchli): standardize this across all plugins in Lightning and Lite. Related refactor: #7324 + return model + + def _setup_optimizer(self, optimizer: Optimizer) -> Optimizer: + """Performs setup for the optimizer, e.g., by wrapping it in another class.""" + # TODO (@awaelchli): standardize this across all plugins in Lightning and Lite. 
Related refactor: #7324 + return optimizer + @property @abstractmethod def on_gpu(self) -> bool: diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 307043d062..1e9ef27455 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1528,7 +1528,7 @@ class Trainer( @property def node_rank(self) -> int: - # some training types define a local rank + # some training types define a node rank return getattr(self.training_type_plugin, "node_rank", 0) @property