Update setup logic in training type plugins [1 / n] (#9994)

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Adrian Wälchli 2021-10-19 17:45:36 +02:00 committed by GitHub
parent e95f9b71c1
commit 854bdc042d
5 changed files with 39 additions and 8 deletions

View File

@@ -202,6 +202,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- LightningLite:
* Added `PrecisionPlugin.forward_context`, making it the default implementation for all `{train,val,test,predict}_step_context()` methods ([#9988](https://github.com/PyTorchLightning/pytorch-lightning/pull/9988))
* Added `DDPSpawnPlugin.spawn()` for spawning new processes of a given function ([#10018](https://github.com/PyTorchLightning/pytorch-lightning/pull/10018))
* Added `TrainingTypePlugin.{_setup_model, _setup_optimizer}` methods ([#9994](https://github.com/PyTorchLightning/pytorch-lightning/pull/9994))
### Changed

View File

@@ -27,6 +27,7 @@ import __main__
import numpy as np
import torch
import torch.distributed
from torch.nn import Module
from torch.nn.parallel.distributed import DistributedDataParallel
import pytorch_lightning as pl
@@ -181,6 +182,10 @@ class DDPPlugin(ParallelPlugin):
self.setup_distributed()
def _setup_model(self, model: Module) -> DistributedDataParallel:
"""Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module."""
return DistributedDataParallel(module=model, device_ids=self.determine_ddp_device_ids(), **self._ddp_kwargs)
def _call_children_scripts(self):
# bookkeeping of spawned processes
self._check_can_spawn_children()
@@ -355,9 +360,7 @@ class DDPPlugin(ParallelPlugin):
def configure_ddp(self) -> None:
self.pre_configure_ddp()
self._model = DistributedDataParallel(
LightningDistributedModule(self.model), device_ids=self.determine_ddp_device_ids(), **self._ddp_kwargs
)
self._model = self._setup_model(LightningDistributedModule(self.model))
self._register_ddp_hooks()
def determine_ddp_device_ids(self):

View File

@@ -21,6 +21,7 @@ import numpy as np
import torch
import torch.distributed
import torch.multiprocessing as mp
from torch.nn import Module
from torch.nn.parallel.distributed import DistributedDataParallel
import pytorch_lightning as pl
@@ -147,6 +148,10 @@ class DDPSpawnPlugin(ParallelPlugin):
smp = mp.get_context("spawn")
self.mp_queue = smp.SimpleQueue()
def _setup_model(self, model: Module) -> DistributedDataParallel:
"""Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module."""
return DistributedDataParallel(module=model, device_ids=self.determine_ddp_device_ids(), **self._ddp_kwargs)
def set_world_ranks(self, process_idx: int = 0) -> None:
self._local_rank = process_idx
if self.cluster_environment is None:
@@ -263,9 +268,7 @@ class DDPSpawnPlugin(ParallelPlugin):
def configure_ddp(self) -> None:
self.pre_configure_ddp()
self._model = DistributedDataParallel(
LightningDistributedModule(self.model), device_ids=self.determine_ddp_device_ids(), **self._ddp_kwargs
)
self._model = self._setup_model(LightningDistributedModule(self.model))
self._register_ddp_hooks()
def determine_ddp_device_ids(self):

View File

@@ -13,11 +13,12 @@
# limitations under the License.
import contextlib
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Generator, Iterable, Mapping, Optional, Union
from typing import Any, Callable, Dict, Generator, Iterable, List, Mapping, Optional, Tuple, Union
import torch
from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer
from torch.utils.data import DataLoader
import pytorch_lightning as pl
@@ -60,6 +61,29 @@ class TrainingTypePlugin(ABC):
def setup(self) -> None:
"""Called by the accelerator to finish setup."""
def _setup_models_and_optimizers(
self, models: List[Module], optimizers: List[Optimizer]
) -> Tuple[List[Module], List[Optimizer]]:
"""Setup multiple models and multiple optimizers together.
The returned objects are expected to be in the same order they were passed in. The default implementation will
call :meth:`_setup_model` and :meth:`_setup_optimizer` on the input lists.
"""
# TODO (@awaelchli): standardize this across all plugins in Lightning and Lite. Related refactor: #7324
models = [self._setup_model(model) for model in models]
optimizers = [self._setup_optimizer(optimizer) for optimizer in optimizers]
return models, optimizers
def _setup_model(self, model: Module) -> Module:
"""Performs setup for the model, e.g., by wrapping it by another class."""
# TODO (@awaelchli): standardize this across all plugins in Lightning and Lite. Related refactor: #7324
return model
def _setup_optimizer(self, optimizer: Optimizer) -> Optimizer:
"""Performs setup for the optimizer, e.g., by wrapping it by another class."""
# TODO (@awaelchli): standardize this across all plugins in Lightning and Lite. Related refactor: #7324
return optimizer
@property
@abstractmethod
def on_gpu(self) -> bool:
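
Note (not part of the diff): the base-class defaults are deliberately pass-through, so existing plugins keep working and only plugins that need wrapping override the hooks; `_setup_models_and_optimizers` simply maps the two hooks over its inputs and preserves their order. A short sketch of that default behaviour, using `SingleDevicePlugin` only because it does not override the new hooks:

# Sketch of the pass-through defaults added to TrainingTypePlugin.
import torch
from torch import nn

from pytorch_lightning.plugins import SingleDevicePlugin

plugin = SingleDevicePlugin(torch.device("cpu"))

models = [nn.Linear(4, 4), nn.Linear(4, 1)]
optimizers = [torch.optim.SGD(m.parameters(), lr=0.1) for m in models]

models, optimizers = plugin._setup_models_and_optimizers(models, optimizers)
# Without an override the objects come back unchanged and in the same order.
assert isinstance(models[0], nn.Linear) and isinstance(optimizers[0], torch.optim.SGD)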

View File

@@ -1528,7 +1528,7 @@ class Trainer(
@property
def node_rank(self) -> int:
# some training types define a local rank
# some training types define a node rank
return getattr(self.training_type_plugin, "node_rank", 0)
@property