Add DDPSpawnPlugin.spawn() (#10018)

parent 0aa220b46b
commit bcb94de90e
CHANGELOG.md

@@ -201,7 +201,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - LightningLite:
     * Added `PrecisionPlugin.forward_context`, making it the default implementation for all `{train,val,test,predict}_step_context()` methods ([#9988](https://github.com/PyTorchLightning/pytorch-lightning/pull/9988))
+    * Added `DDPSpawnPlugin.spawn()` for spawning new processes of a given function ([#10018](https://github.com/PyTorchLightning/pytorch-lightning/pull/10018))

 ### Changed
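The changelog entry above is the whole public surface of this commit: a single `spawn()` method that runs an arbitrary function in every DDP worker process. A minimal sketch of how it could be exercised, assuming the 1.5-era Trainer flags this commit targets; `worker_fn` and the standalone invocation are illustrative, not part of the Lightning API:

import pytorch_lightning as pl
from pytorch_lightning.plugins import DDPSpawnPlugin


def worker_fn(message: str) -> None:
    # Runs once in each spawned worker, after the plugin has seeded the
    # process and set up its rank (see _worker_setup in the diff below).
    print(message)


if __name__ == "__main__":
    # Hypothetical setup: two CPU workers driven by the DDP-spawn plugin.
    trainer = pl.Trainer(num_processes=2, accelerator="ddp_cpu")
    plugin = trainer.training_type_plugin
    assert isinstance(plugin, DDPSpawnPlugin)
    # Arguments must be pickleable: they cross the process boundary
    # through torch.multiprocessing.spawn.
    plugin.spawn(worker_fn, "hello from a spawned worker")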
pytorch_lightning/plugins/training_type/ddp_spawn.py

@@ -15,7 +15,7 @@ import logging
 import os
 import re
 from multiprocessing.queues import SimpleQueue
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union

 import numpy as np
 import torch
@@ -155,38 +155,45 @@ class DDPSpawnPlugin(ParallelPlugin):
         self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
         rank_zero_only.rank = self.cluster_environment.global_rank()

-    def get_mp_spawn_kwargs(self, trainer: "pl.Trainer") -> dict:
-        return {"args": (trainer, self.mp_queue), "nprocs": self.num_processes}
+    def get_mp_spawn_kwargs(self, trainer: Optional["pl.Trainer"] = None) -> Dict[str, Any]:
+        return {"nprocs": self.num_processes}

     def start_training(self, trainer: "pl.Trainer") -> None:
-        mp.spawn(self.new_process, **self.get_mp_spawn_kwargs(trainer))
+        self.spawn(self.new_process, trainer, self.mp_queue)
         # reset optimizers, since main process is never used for training and thus does not have a valid optim state
         trainer.optimizers = []

     def start_evaluating(self, trainer: "pl.Trainer") -> None:
-        mp.spawn(self.new_process, **self.get_mp_spawn_kwargs(trainer))
+        self.spawn(self.new_process, trainer, self.mp_queue)

     def start_predicting(self, trainer: "pl.Trainer") -> None:
-        mp.spawn(self.new_process, **self.get_mp_spawn_kwargs(trainer))
+        self.spawn(self.new_process, trainer, self.mp_queue)

-    def new_process(self, process_idx: int, trainer: "pl.Trainer", mp_queue: SimpleQueue) -> None:
-        self.mp_queue = mp_queue
+    def spawn(self, function: Callable, *args: Any, **kwargs: Any) -> None:
+        """Spawn processes that run the given function.
+
+        Args:
+            function: The function to spawn processes from.
+            *args: Optional positional arguments that will be passed to the function in addition to the process index.
+                These arguments must be pickleable.
+            **kwargs: Optional named arguments that will be passed to the function in addition to the process index.
+                These arguments must be pickleable.
+        """
+        os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port())
+        mp.spawn(self._wrapped_function, args=(function, args, kwargs), **self.get_mp_spawn_kwargs())
+
+    def _wrapped_function(self, process_idx: int, function: Callable, args: Any, kwargs: Any) -> None:
+        self._worker_setup(process_idx)
+        function(*args, **kwargs)

+    def _worker_setup(self, process_idx: int):
         reset_seed()

         self.set_world_ranks(process_idx)

         # set warning rank
         rank_zero_only.rank = self.global_rank

         # set up server using proc 0's ip address
         # try to init for 20 times at max in case ports are taken
         # where to store ip_table
         init_ddp_connection(self.cluster_environment, self.torch_distributed_backend, self.global_rank, self.world_size)

-        # TODO: we moved it to the trainer.fit after calling pre_dispatch
-        # ... need to double check that it is the correct place
-        # self.trainer.call_setup_hook(self.model)
+    def new_process(self, trainer: "pl.Trainer", mp_queue: SimpleQueue) -> None:
+        self.mp_queue = mp_queue

         # move the model to the correct device
         self.model_to_device()
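Note how the hunk divides the work: torch.multiprocessing always hands each worker its process index, `_wrapped_function` consumes that index for `_worker_setup` (seeding, rank bookkeeping, process-group init) and forwards only `*args`/`**kwargs` to the target, so the docstring's mention of the process index describes the wrapper, not the target function's signature. A self-contained sketch of the same wrapped-function pattern using plain `torch.multiprocessing`; all names here are illustrative, not Lightning API:

from typing import Any, Callable

import torch.multiprocessing as mp


def _wrapped(process_idx: int, function: Callable, args: Any, kwargs: Any) -> None:
    # Stand-in for DDPSpawnPlugin._worker_setup: per-process bookkeeping
    # happens here, keyed on the index that mp.spawn prepends.
    print(f"setting up worker {process_idx}")
    function(*args, **kwargs)


def greet(message: str, punctuation: str = "!") -> None:
    print(message + punctuation)


if __name__ == "__main__":
    # Mirrors spawn(): pack the target and its pickleable arguments into
    # the args tuple of torch.multiprocessing.spawn.
    mp.spawn(_wrapped, args=(greet, ("hello",), {"punctuation": "?"}), nprocs=2)

Packing `args` and `kwargs` into a single tuple keeps mp.spawn's pickling contract simple while letting callers pass arbitrary signatures through `spawn()`.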
pytorch_lightning/plugins/training_type/sharded_spawn.py

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from contextlib import contextmanager
+from multiprocessing.queues import SimpleQueue
 from typing import Dict, Generator, Optional

 import torch
@@ -100,13 +101,13 @@ class DDPSpawnShardedPlugin(DDPSpawnPlugin):
     def post_training_step(self):
         pass

-    def new_process(self, process_idx, trainer, mp_queue):
+    def new_process(self, trainer: "pl.Trainer", mp_queue: SimpleQueue) -> None:
         # Ensure that the scaler points to the correct process group
         # which is re-initialized in a new process
         precision_plugin = trainer.accelerator.precision_plugin
         if isinstance(precision_plugin, ShardedNativeMixedPrecisionPlugin):
             precision_plugin.scaler = ShardedGradScaler()
-        super().new_process(process_idx, trainer, mp_queue)
+        return super().new_process(trainer, mp_queue)

     @classmethod
     def register_plugins(cls, plugin_registry: Dict) -> None:
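As the comment in the hunk notes, this override exists because objects created in the parent are pickled into each worker, while the process group a scaler must talk to is only initialized inside the worker; rebuilding the scaler after per-process setup keeps the two in sync. A standalone sketch of that per-worker re-initialization (the `State` container is illustrative, and a CPU-safe `GradScaler(enabled=False)` stands in for `ShardedGradScaler`):

import torch.multiprocessing as mp
from torch.cuda.amp import GradScaler


class State:
    """Carries a scaler created in the parent process."""

    def __init__(self) -> None:
        self.scaler = GradScaler(enabled=False)


def worker(process_idx: int, state: State) -> None:
    # The inherited scaler predates this worker's process group, so it is
    # replaced, just as DDPSpawnShardedPlugin.new_process swaps in a fresh
    # ShardedGradScaler before delegating to the parent class.
    state.scaler = GradScaler(enabled=False)
    print(f"worker {process_idx}: scaler rebuilt, enabled={state.scaler.is_enabled()}")


if __name__ == "__main__":
    mp.spawn(worker, args=(State(),), nprocs=2)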