Update setup logic in training type plugins [1 / n] (#9994)
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent e95f9b71c1
commit 854bdc042d
@@ -202,6 +202,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - LightningLite:
     * Added `PrecisionPlugin.forward_context`, making it the default implementation for all `{train,val,test,predict}_step_context()` methods ([#9988](https://github.com/PyTorchLightning/pytorch-lightning/pull/9988))
     * Added `DDPSpawnPlugin.spawn()` for spawning new processes of a given function ([#10018](https://github.com/PyTorchLightning/pytorch-lightning/pull/10018))
+    * Added `TrainingTypePlugin.{_setup_model, _setup_optimizer}` methods ([#9994](https://github.com/PyTorchLightning/pytorch-lightning/pull/9994))

 ### Changed

@@ -27,6 +27,7 @@ import __main__
 import numpy as np
 import torch
 import torch.distributed
+from torch.nn import Module
 from torch.nn.parallel.distributed import DistributedDataParallel

 import pytorch_lightning as pl
@@ -181,6 +182,10 @@ class DDPPlugin(ParallelPlugin):

         self.setup_distributed()

+    def _setup_model(self, model: Module) -> DistributedDataParallel:
+        """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module."""
+        return DistributedDataParallel(module=model, device_ids=self.determine_ddp_device_ids(), **self._ddp_kwargs)
+
     def _call_children_scripts(self):
         # bookkeeping of spawned processes
         self._check_can_spawn_children()
@@ -355,9 +360,7 @@

     def configure_ddp(self) -> None:
         self.pre_configure_ddp()
-        self._model = DistributedDataParallel(
-            LightningDistributedModule(self.model), device_ids=self.determine_ddp_device_ids(), **self._ddp_kwargs
-        )
+        self._model = self._setup_model(LightningDistributedModule(self.model))
         self._register_ddp_hooks()

     def determine_ddp_device_ids(self):
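For reference, the new `_setup_model` hook reduces to wrapping a module in `DistributedDataParallel`, and `configure_ddp` now routes through it. Below is a minimal single-process sketch of the equivalent wrapping on CPU; the gloo process-group bootstrap and the Linear model are illustrative only and not part of this diff.

# Standalone sketch (not from this commit): single-process CPU equivalent of the
# DistributedDataParallel wrapping that DDPPlugin._setup_model performs.
import os

import torch.distributed as dist
from torch.nn import Linear
from torch.nn.parallel.distributed import DistributedDataParallel

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

model = Linear(8, 8)
# On CPU there are no device ids, so device_ids stays None; DDPPlugin instead
# passes determine_ddp_device_ids() and its **self._ddp_kwargs at this point.
ddp_model = DistributedDataParallel(module=model, device_ids=None)

dist.destroy_process_group()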
@@ -21,6 +21,7 @@ import numpy as np
 import torch
 import torch.distributed
 import torch.multiprocessing as mp
+from torch.nn import Module
 from torch.nn.parallel.distributed import DistributedDataParallel

 import pytorch_lightning as pl
@@ -147,6 +148,10 @@ class DDPSpawnPlugin(ParallelPlugin):
         smp = mp.get_context("spawn")
         self.mp_queue = smp.SimpleQueue()

+    def _setup_model(self, model: Module) -> DistributedDataParallel:
+        """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module."""
+        return DistributedDataParallel(module=model, device_ids=self.determine_ddp_device_ids(), **self._ddp_kwargs)
+
     def set_world_ranks(self, process_idx: int = 0) -> None:
         self._local_rank = process_idx
         if self.cluster_environment is None:
@@ -263,9 +268,7 @@

     def configure_ddp(self) -> None:
         self.pre_configure_ddp()
-        self._model = DistributedDataParallel(
-            LightningDistributedModule(self.model), device_ids=self.determine_ddp_device_ids(), **self._ddp_kwargs
-        )
+        self._model = self._setup_model(LightningDistributedModule(self.model))
         self._register_ddp_hooks()

     def determine_ddp_device_ids(self):
@@ -13,11 +13,12 @@
 # limitations under the License.
 import contextlib
 from abc import ABC, abstractmethod
-from typing import Any, Callable, Dict, Generator, Iterable, Mapping, Optional, Union
+from typing import Any, Callable, Dict, Generator, Iterable, List, Mapping, Optional, Tuple, Union

 import torch
 from torch import Tensor
+from torch.nn import Module
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader

 import pytorch_lightning as pl
@@ -60,6 +61,29 @@ class TrainingTypePlugin(ABC):
     def setup(self) -> None:
         """Called by the accelerator to finish setup."""

+    def _setup_models_and_optimizers(
+        self, models: List[Module], optimizers: List[Optimizer]
+    ) -> Tuple[List[Module], List[Optimizer]]:
+        """Setup multiple models and multiple optimizers together.
+
+        The returned objects are expected to be in the same order they were passed in. The default implementation will
+        call :meth:`_setup_model` and :meth:`_setup_optimizer` on the input lists.
+        """
+        # TODO (@awaelchli): standardize this across all plugins in Lightning and Lite. Related refactor: #7324
+        models = [self._setup_model(model) for model in models]
+        optimizers = [self._setup_optimizer(optimizer) for optimizer in optimizers]
+        return models, optimizers
+
+    def _setup_model(self, model: Module) -> Module:
+        """Performs setup for the model, e.g., by wrapping it by another class."""
+        # TODO (@awaelchli): standardize this across all plugins in Lightning and Lite. Related refactor: #7324
+        return model
+
+    def _setup_optimizer(self, optimizer: Optimizer) -> Optimizer:
+        """Performs setup for the optimizer, e.g., by wrapping it by another class."""
+        # TODO (@awaelchli): standardize this across all plugins in Lightning and Lite. Related refactor: #7324
+        return optimizer
+
     @property
     @abstractmethod
     def on_gpu(self) -> bool:
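As a rough illustration of the delegation contract these defaults establish, here is a standalone toy sketch (assumed class names, not the Lightning plugin hierarchy): a subclass only overrides the per-object hooks, and the combined method preserves input order.

# Self-contained sketch (not from this commit): mirrors how the base implementation
# fans out to _setup_model / _setup_optimizer for each passed-in object.
from typing import List, Tuple

from torch.nn import Linear, Module, Sequential
from torch.optim import SGD, Optimizer


class ToyPlugin:
    def _setup_models_and_optimizers(
        self, models: List[Module], optimizers: List[Optimizer]
    ) -> Tuple[List[Module], List[Optimizer]]:
        # Same shape as the new default: delegate per object, keep the order.
        return [self._setup_model(m) for m in models], [self._setup_optimizer(o) for o in optimizers]

    def _setup_model(self, model: Module) -> Module:
        return model  # no-op by default

    def _setup_optimizer(self, optimizer: Optimizer) -> Optimizer:
        return optimizer  # no-op by default


class ToyWrappingPlugin(ToyPlugin):
    # A subclass overrides only the hook it cares about, analogous to
    # DDPPlugin._setup_model wrapping the model in DistributedDataParallel.
    def _setup_model(self, model: Module) -> Module:
        return Sequential(model)  # stand-in for a real wrapper


model = Linear(4, 2)
optimizer = SGD(model.parameters(), lr=0.1)
models, optimizers = ToyWrappingPlugin()._setup_models_and_optimizers([model], [optimizer])
assert isinstance(models[0], Sequential) and isinstance(optimizers[0], SGD)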
@@ -1528,7 +1528,7 @@ class Trainer(

     @property
     def node_rank(self) -> int:
-        # some training types define a local rank
+        # some training types define a node rank
         return getattr(self.training_type_plugin, "node_rank", 0)

     @property