Add DDPSpawnPlugin.spawn() (#10018)

Adrian Wälchli 2021-10-19 16:34:47 +02:00 committed by GitHub
parent 0aa220b46b
commit bcb94de90e
3 changed files with 29 additions and 21 deletions

View File

@@ -201,7 +201,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - LightningLite:
     * Added `PrecisionPlugin.forward_context`, making it the default implementation for all `{train,val,test,predict}_step_context()` methods ([#9988](https://github.com/PyTorchLightning/pytorch-lightning/pull/9988))
+    * Added `DDPSpawnPlugin.spawn()` for spawning new processes of a given function ([#10018](https://github.com/PyTorchLightning/pytorch-lightning/pull/10018))
-
 
 ### Changed

View File

@@ -15,7 +15,7 @@ import logging
 import os
 import re
 from multiprocessing.queues import SimpleQueue
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union
 
 import numpy as np
 import torch
@@ -155,38 +155,45 @@ class DDPSpawnPlugin(ParallelPlugin):
         self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
         rank_zero_only.rank = self.cluster_environment.global_rank()
 
-    def get_mp_spawn_kwargs(self, trainer: "pl.Trainer") -> dict:
-        return {"args": (trainer, self.mp_queue), "nprocs": self.num_processes}
+    def get_mp_spawn_kwargs(self, trainer: Optional["pl.Trainer"] = None) -> Dict[str, Any]:
+        return {"nprocs": self.num_processes}
 
     def start_training(self, trainer: "pl.Trainer") -> None:
-        mp.spawn(self.new_process, **self.get_mp_spawn_kwargs(trainer))
+        self.spawn(self.new_process, trainer, self.mp_queue)
         # reset optimizers, since main process is never used for training and thus does not have a valid optim state
         trainer.optimizers = []
 
     def start_evaluating(self, trainer: "pl.Trainer") -> None:
-        mp.spawn(self.new_process, **self.get_mp_spawn_kwargs(trainer))
+        self.spawn(self.new_process, trainer, self.mp_queue)
 
     def start_predicting(self, trainer: "pl.Trainer") -> None:
-        mp.spawn(self.new_process, **self.get_mp_spawn_kwargs(trainer))
+        self.spawn(self.new_process, trainer, self.mp_queue)
 
-    def new_process(self, process_idx: int, trainer: "pl.Trainer", mp_queue: SimpleQueue) -> None:
-        self.mp_queue = mp_queue
+    def spawn(self, function: Callable, *args: Any, **kwargs: Any) -> None:
+        """Spawn processes that run the given function.
 
+        Args:
+            function: The function to spawn processes from.
+            *args: Optional positional arguments that will be passed to the function in addition to the process index.
+                These arguments must be pickleable.
+            **kwargs: Optional named arguments that will be passed to the function in addition to the process index.
+                These arguments must be pickleable.
+        """
+        os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port())
+        mp.spawn(self._wrapped_function, args=(function, args, kwargs), **self.get_mp_spawn_kwargs())
+
+    def _wrapped_function(self, process_idx: int, function: Callable, args: Any, kwargs: Any) -> None:
+        self._worker_setup(process_idx)
+        function(*args, **kwargs)
+
+    def _worker_setup(self, process_idx: int):
         reset_seed()
-
         self.set_world_ranks(process_idx)
-
-        # set warning rank
         rank_zero_only.rank = self.global_rank
-
-        # set up server using proc 0's ip address
-        # try to init for 20 times at max in case ports are taken
-        # where to store ip_table
         init_ddp_connection(self.cluster_environment, self.torch_distributed_backend, self.global_rank, self.world_size)
 
-        # TODO: we moved it to the trainer.fit after calling pre_dispatch
-        #   ... need to double check that it is the correct place
-        # self.trainer.call_setup_hook(self.model)
+    def new_process(self, trainer: "pl.Trainer", mp_queue: SimpleQueue) -> None:
+        self.mp_queue = mp_queue
 
         # move the model to the correct device
         self.model_to_device()
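
Note: for readers unfamiliar with the wrapper pattern above, the following is a standalone sketch of what `spawn()` plus `_wrapped_function()` boil down to, outside Lightning. The helper names (`spawn`, `_wrapped`) and the `nprocs` default are illustrative only and not part of this commit; the real plugin additionally performs per-process setup in `_worker_setup()` and takes `nprocs` from `get_mp_spawn_kwargs()`.

    from typing import Any, Callable

    import torch.multiprocessing as mp


    def _wrapped(process_idx: int, function: Callable, args: Any, kwargs: Any) -> None:
        # mp.spawn always passes the process index first; per-process setup
        # (seeding, rank bookkeeping, process-group init) would happen here
        # before the user function runs.
        function(*args, **kwargs)


    def spawn(function: Callable, *args: Any, nprocs: int = 2, **kwargs: Any) -> None:
        # Forward the function and its (pickleable) arguments through mp.spawn,
        # mirroring DDPSpawnPlugin.spawn -> _wrapped_function above.
        mp.spawn(_wrapped, args=(function, args, kwargs), nprocs=nprocs)


    if __name__ == "__main__":
        spawn(print, "hello from a spawned worker", nprocs=2)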

View File

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from contextlib import contextmanager
+from multiprocessing.queues import SimpleQueue
 from typing import Dict, Generator, Optional
 
 import torch
@@ -100,13 +101,13 @@ class DDPSpawnShardedPlugin(DDPSpawnPlugin):
     def post_training_step(self):
         pass
 
-    def new_process(self, process_idx, trainer, mp_queue):
+    def new_process(self, trainer: "pl.Trainer", mp_queue: SimpleQueue) -> None:
         # Ensure that the scaler points to the correct process group
         # which is re-initialized in a new process
         precision_plugin = trainer.accelerator.precision_plugin
         if isinstance(precision_plugin, ShardedNativeMixedPrecisionPlugin):
             precision_plugin.scaler = ShardedGradScaler()
-        super().new_process(process_idx, trainer, mp_queue)
+        return super().new_process(trainer, mp_queue)
 
     @classmethod
     def register_plugins(cls, plugin_registry: Dict) -> None:
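
Note on the signature change above: `new_process()` no longer receives the process index (it is consumed by `_worker_setup()` inside the spawn wrapper), so overrides drop `process_idx` and forward only the trainer and queue, as `DDPSpawnShardedPlugin` does here. A minimal sketch of a custom override under the new signature follows; the subclass name is hypothetical and not part of this commit.

    from multiprocessing.queues import SimpleQueue

    import pytorch_lightning as pl
    from pytorch_lightning.plugins import DDPSpawnPlugin


    class MyCustomSpawnPlugin(DDPSpawnPlugin):
        # Hypothetical subclass: after this commit, overrides take (trainer, mp_queue)
        # only and forward both to the parent implementation.
        def new_process(self, trainer: "pl.Trainer", mp_queue: SimpleQueue) -> None:
            # custom per-worker preparation could go here
            return super().new_process(trainer, mp_queue)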