1/n Simplify spawn plugins: Simplify handling of multiprocessing queue (#10034)
Co-authored-by: thomas chaton <thomas@grid.ai>
parent 541b983b90
commit 98cb7e8790
CHANGELOG.md
@@ -80,7 +80,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Moved `batch_to_device` method from `Accelerator` to `TrainingTypePlugin` ([#10649](https://github.com/PyTorchLightning/pytorch-lightning/pull/10649))
 -
+- The `DDPSpawnPlugin` no longer overrides the `post_dispatch` plugin hook ([#10034](https://github.com/PyTorchLightning/pytorch-lightning/pull/10034))
+- The `LightningModule.{add_to_queue,get_from_queue}` hooks no longer get a `torch.multiprocessing.SimpleQueue` and instead receive a list based queue ([#10034](https://github.com/PyTorchLightning/pytorch-lightning/pull/10034))

 ### Deprecated
@@ -188,6 +192,12 @@
 - Removed argument `return_result` from the `DDPSpawnPlugin.spawn()` method ([#10867](https://github.com/PyTorchLightning/pytorch-lightning/pull/10867))
+- Removed the property `TrainingTypePlugin.results` and corresponding properties in subclasses ([#10034](https://github.com/PyTorchLightning/pytorch-lightning/pull/10034))
+- Removed the `mp_queue` attribute from `DDPSpawnPlugin` and `TPUSpawnPlugin` ([#10034](https://github.com/PyTorchLightning/pytorch-lightning/pull/10034))
 - Removed unnecessary `_move_optimizer_state` method overrides from `TPUSpawnPlugin` and `SingleTPUPlugin` ([#10849](https://github.com/PyTorchLightning/pytorch-lightning/pull/10849))
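For orientation, here is a minimal sketch of what the list-based queue means for user code (hypothetical `MyModel`; the pattern mirrors the test changes at the bottom of this diff). The queue still exposes `put()` and `get()`, so existing hook overrides keep working without importing anything from `torch.multiprocessing`:

    import pytorch_lightning as pl

    class MyModel(pl.LightningModule):
        def add_to_queue(self, queue) -> None:
            # runs in the spawned worker: push any picklable extra state first
            queue.put("some_extra_state")
            return super().add_to_queue(queue)

        def get_from_queue(self, queue) -> None:
            # runs back in the main process; items come out in the same FIFO order
            self.some_extra_state = queue.get()
            return super().get_from_queue(queue)

The hunks below show the corresponding signature changes in `LightningModule` and in the spawn plugins.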
@@ -1917,7 +1917,7 @@ class LightningModule(
         )
         return get_model_size_mb(self)

-    def add_to_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None:
+    def add_to_queue(self, queue: pl.plugins.training_type.ddp_spawn._FakeQueue) -> None:
         """Appends the :attr:`trainer.callback_metrics` dictionary to the given queue. To avoid issues with memory
         sharing, we cast the data to numpy.

@@ -1931,7 +1931,7 @@ class LightningModule(
         if self.trainer and isinstance(self.trainer.training_type_plugin, pl.plugins.training_type.DDPSpawnPlugin):
             self.trainer.training_type_plugin.add_to_queue(self.trainer, queue)

-    def get_from_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None:
+    def get_from_queue(self, queue: pl.plugins.training_type.ddp_spawn._FakeQueue) -> None:
         """Retrieve the :attr:`trainer.callback_metrics` dictionary from the given queue. To preserve consistency,
         we cast back the data to ``torch.Tensor``.
@@ -14,8 +14,9 @@
 import logging
 import os
 import re
+from collections import UserList
 from multiprocessing.queues import SimpleQueue
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, NamedTuple, Optional, Union

 import numpy as np
 import torch
@@ -45,7 +46,7 @@ from pytorch_lightning.utilities.distributed import (
 from pytorch_lightning.utilities.enums import _StrategyType
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.seed import reset_seed
-from pytorch_lightning.utilities.types import STEP_OUTPUT
+from pytorch_lightning.utilities.types import _PATH, STEP_OUTPUT

 if _TORCH_GREATER_EQUAL_1_8:
     from pytorch_lightning.utilities.distributed import register_ddp_comm_hook
@@ -80,7 +81,6 @@ class DDPSpawnPlugin(ParallelPlugin):
         self.sync_batchnorm = False
         self._ddp_kwargs = kwargs
         self.num_processes = len(parallel_devices) if parallel_devices is not None else 0
-        self.mp_queue = None
         self._ddp_comm_state = ddp_comm_state
         self._ddp_comm_hook = ddp_comm_hook
         self._ddp_comm_wrapper = ddp_comm_wrapper
@@ -101,15 +101,6 @@
     def local_rank(self) -> int:
         return self._local_rank

-    def __getstate__(self):
-        """Makes this plugin pickleable without destroying the queue in the current process."""
-        state = self.__dict__.copy()
-        state["mp_queue"] = None
-        return state
-
-    def __setstate__(self, state):
-        self.__dict__ = state
-
     @property
     def root_device(self):
         return self.parallel_devices[self.local_rank]
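The removed `__getstate__`/`__setstate__` pair existed only because a `multiprocessing.SimpleQueue` attribute cannot survive pickling, which happens when the plugin is shipped to the spawned workers. A small standalone illustration of that failure mode (generic code, not the plugin itself):

    import pickle
    from multiprocessing import get_context

    class Holder:
        """Stand-in for an object that keeps a SimpleQueue as an attribute."""

        def __init__(self) -> None:
            self.mp_queue = get_context("spawn").SimpleQueue()

    try:
        pickle.dumps(Holder())
    except RuntimeError as err:
        # multiprocessing queues may only be shared with child processes at creation time
        print(f"pickling failed as expected: {err}")

With the `mp_queue` attribute gone, there is nothing left to hide from the pickler and the hooks can be dropped.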
@@ -125,9 +116,6 @@

     def setup(self, trainer: "pl.Trainer") -> None:
         os.environ["MASTER_PORT"] = str(self.cluster_environment.main_port)
-        # pass in a state q
-        smp = mp.get_context("spawn")
-        self.mp_queue = smp.SimpleQueue()
         super().setup(trainer)

     def _setup_model(self, model: Module) -> DistributedDataParallel:
@@ -145,18 +133,24 @@
     def get_mp_spawn_kwargs(self, trainer: Optional["pl.Trainer"] = None) -> Dict[str, Any]:
         return {"nprocs": self.num_processes}

-    def start_training(self, trainer: "pl.Trainer") -> None:
-        self.spawn(self.new_process, trainer, self.mp_queue)
+    def start_training(self, trainer: "pl.Trainer") -> Any:
+        spawn_output: _SpawnOutput = self.spawn(self.new_process, trainer)
+        self.__recover_results_in_main_process(spawn_output, trainer)
         # reset optimizers, since main process is never used for training and thus does not have a valid optim state
         trainer.optimizers = []
+        return spawn_output.trainer_results

-    def start_evaluating(self, trainer: "pl.Trainer") -> None:
-        self.spawn(self.new_process, trainer, self.mp_queue)
+    def start_evaluating(self, trainer: "pl.Trainer") -> Any:
+        spawn_output: _SpawnOutput = self.spawn(self.new_process, trainer)
+        self.__recover_results_in_main_process(spawn_output, trainer)
+        return spawn_output.trainer_results

-    def start_predicting(self, trainer: "pl.Trainer") -> None:
-        self.spawn(self.new_process, trainer, self.mp_queue)
+    def start_predicting(self, trainer: "pl.Trainer") -> Any:
+        spawn_output: _SpawnOutput = self.spawn(self.new_process, trainer)
+        self.__recover_results_in_main_process(spawn_output, trainer)
+        return spawn_output.trainer_results

-    def spawn(self, function: Callable, *args: Any, **kwargs: Any) -> Optional[Any]:
+    def spawn(self, function: Callable, *args: Any, **kwargs: Any) -> Optional[Union[Any, "_SpawnOutput"]]:
         """Spawn processes that run the given function.

         Args:
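The plugin's `spawn()` follows the usual `torch.multiprocessing` pattern: run a function in worker processes and ship rank zero's return value back to the parent through a single-use queue created before spawning. A self-contained sketch of that pattern (simplified, not the plugin implementation):

    import torch.multiprocessing as mp

    def _worker(rank: int, return_queue) -> None:
        # ... the actual per-process work would happen here ...
        result = {"rank": rank, "value": rank * 2}
        if rank == 0:
            # only rank zero reports back to the parent process
            return_queue.put(result)

    def spawn_and_collect(nprocs: int = 2):
        context = mp.get_context("spawn")
        return_queue = context.SimpleQueue()
        mp.spawn(_worker, args=(return_queue,), nprocs=nprocs)
        # the parent resumes here once all workers have exited
        return return_queue.get()

    if __name__ == "__main__":
        print(spawn_and_collect())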
@@ -191,9 +185,7 @@
             self.cluster_environment, self.torch_distributed_backend, self.global_rank, self.world_size
         )

-    def new_process(self, trainer: "pl.Trainer", mp_queue: SimpleQueue) -> None:
-        self.mp_queue = mp_queue
-
+    def new_process(self, trainer: "pl.Trainer") -> Optional["_SpawnOutput"]:
         # move the model to the correct device
         self.model_to_device()
@@ -208,28 +200,11 @@
         self.barrier()

         results = trainer.run_stage()

-        # persist info in ddp_spawn
-        self.__transfer_distrib_spawn_state_on_fit_end(trainer, results)
+        outputs = self.__collect_rank_zero_results(trainer, results)

         # ensure that spawned processes go through teardown before joining
         trainer._call_teardown_hook()
-
-    def post_dispatch(self, trainer: "pl.Trainer"):
-        # restore main state with best weights
-        best_path = self.mp_queue.get()
-        last_path = self.mp_queue.get()
-        self._results = self.mp_queue.get()
-        # get the `callback_metrics` and set it to the trainer
-        # only in case the user does not override it.
-        # TODO: Remove the if in v1.7
-        if is_overridden("get_from_queue", self.lightning_module):
-            self.lightning_module.get_from_queue(self.mp_queue)
-        else:
-            self.get_from_queue(trainer, self.mp_queue)
-
-        # recover the weights of the processes trained in the children
-        self.__recover_child_process_weights(best_path, last_path)
+        return outputs

     def pre_configure_ddp(self):
         # if unset, default `find_unused_parameters` `True`
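The removed `post_dispatch` had to call `mp_queue.get()` three times in exactly the order the worker called `put()`. The replacement packs the same pieces into a named tuple, so the parent addresses them by field name instead of by position. An illustrative sketch with simplified types and made-up values:

    from typing import Any, NamedTuple, Optional

    class SpawnOutput(NamedTuple):
        """Illustrative stand-in for the `_SpawnOutput` added in this diff."""

        best_model_path: Optional[str]
        last_path: Optional[str]
        trainer_results: Any
        extra: list

    def worker_side() -> SpawnOutput:
        # rank zero packs everything the parent needs into a single object
        return SpawnOutput("best.ckpt", "best.tmp_end.ckpt", {"val_loss": 0.1}, ["extra_metric"])

    def parent_side(out: SpawnOutput) -> Any:
        # fields are addressed by name, so there is no get()-ordering to keep in sync
        assert out.best_model_path == "best.ckpt"
        return out.trainer_results

    assert parent_side(worker_side()) == {"val_loss": 0.1}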
@@ -268,7 +243,7 @@
             return None
         return [self.root_device.index]

-    def __transfer_distrib_spawn_state_on_fit_end(self, trainer: "pl.Trainer", results: Any) -> None:
+    def __collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_SpawnOutput"]:
         rank_zero_warn("cleaning up ddp environment...")
         checkpoint_callback = trainer.checkpoint_callback
         best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None
@@ -285,28 +260,37 @@
             last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path)
             self.checkpoint_io.save_checkpoint(state_dict, last_path)

-        # todo, pass complete checkpoint as state dictionary
-        self.mp_queue.put(best_model_path)
-        self.mp_queue.put(last_path)
-        self.mp_queue.put(results)
         # adds the `callback_metrics` to the queue
-        # TODO: Remove the if in v1.7
+        extra = _FakeQueue()
         if is_overridden("add_to_queue", self.lightning_module):
-            self.lightning_module.add_to_queue(self.mp_queue)
+            # TODO: Remove the if in v1.7
+            self.lightning_module.add_to_queue(extra)
         else:
-            self.add_to_queue(trainer, self.mp_queue)
+            self.add_to_queue(trainer, extra)

-    def __recover_child_process_weights(self, best_path, last_path):
+        return _SpawnOutput(best_model_path, last_path, results, extra)
+
+    def __recover_results_in_main_process(self, spawn_output: "_SpawnOutput", trainer) -> None:
         # transfer back the best path to the trainer
         if self.lightning_module.trainer.checkpoint_callback:
-            self.lightning_module.trainer.checkpoint_callback.best_model_path = best_path
-        # todo, pass also best score
+            self.lightning_module.trainer.checkpoint_callback.best_model_path = spawn_output.best_model_path
+
+        # TODO: pass also best score
         # load last weights
-        if last_path is not None and self.lightning_module.trainer.state.fn == TrainerFn.FITTING:
-            ckpt = self.checkpoint_io.load_checkpoint(last_path, map_location=(lambda storage, loc: storage))
+        if spawn_output.last_path is not None and self.lightning_module.trainer.state.fn == TrainerFn.FITTING:
+            ckpt = self.checkpoint_io.load_checkpoint(
+                spawn_output.last_path, map_location=(lambda storage, loc: storage)
+            )
             self.lightning_module.load_state_dict(ckpt)

+        # get the `callback_metrics` and set it to the trainer
+        if is_overridden("get_from_queue", self.lightning_module):
+            # only in case the user does not override it.
+            # TODO: Remove the if in v1.7
+            self.lightning_module.get_from_queue(spawn_output.extra)
+        else:
+            self.get_from_queue(trainer, spawn_output.extra)
+
     def barrier(self, *args, **kwargs) -> None:
         if not distributed_available():
             return
@@ -372,11 +356,12 @@
         if not self.lightning_module.automatic_optimization:
             self.model.require_backward_grad_sync = True

-    def add_to_queue(self, trainer: "pl.Trainer", queue: torch.multiprocessing.SimpleQueue) -> None:
+    def add_to_queue(self, trainer: "pl.Trainer", queue: "_FakeQueue") -> None:
         """Appends the :attr:`trainer.callback_metrics` dictionary to the given queue. To avoid issues with memory
         sharing, we cast the data to numpy.

         Args:
             trainer: reference to the Trainer.
             queue: the instance of the queue to append the data.
         """
         callback_metrics: dict = apply_to_collection(
@@ -384,11 +369,12 @@
         )  # send as numpy to avoid issues with memory sharing
         queue.put(callback_metrics)

-    def get_from_queue(self, trainer: "pl.Trainer", queue: torch.multiprocessing.SimpleQueue) -> None:
+    def get_from_queue(self, trainer: "pl.Trainer", queue: "_FakeQueue") -> None:
         """Retrieve the :attr:`trainer.callback_metrics` dictionary from the given queue. To preserve consistency,
         we cast back the data to ``torch.Tensor``.

         Args:
             trainer: reference to the Trainer.
             queue: the instance of the queue from where to get the data.
         """
         # NOTE: `add_to_queue` needs to be called before
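The docstrings above describe the round trip the callback metrics take: cast to numpy before crossing the process boundary, cast back to `torch.Tensor` afterwards. A minimal sketch of that conversion with `apply_to_collection` (the helper these plugins already import); the metric names are illustrative:

    import numpy as np
    import torch
    from pytorch_lightning.utilities.apply_func import apply_to_collection

    callback_metrics = {"val_loss": torch.tensor(0.25), "epoch": torch.tensor(3)}

    # worker side: detach into plain numpy so no shared tensor storage crosses processes
    as_numpy = apply_to_collection(callback_metrics, torch.Tensor, lambda t: t.cpu().numpy())

    # main-process side: restore torch.Tensor so downstream code sees the usual types
    as_tensors = apply_to_collection(as_numpy, np.ndarray, lambda a: torch.tensor(a))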
@@ -413,3 +399,23 @@
             self.lightning_module.cpu()
             # clean up memory
             torch.cuda.empty_cache()
+
+
+class _FakeQueue(UserList):
+    """Simulates a :class:`torch.multiprocessing.queue.SimpleQueue` interface using the Python list."""
+
+    def get(self) -> Any:
+        return self.pop(0)
+
+    def put(self, item: Any) -> None:
+        self.append(item)
+
+    def empty(self) -> bool:
+        return len(self) == 0
+
+
+class _SpawnOutput(NamedTuple):
+    best_model_path: Optional[_PATH]
+    last_path: Optional[_PATH]
+    trainer_results: Any
+    extra: _FakeQueue
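Because `_FakeQueue` is just a `UserList`, it keeps the FIFO `put()`/`get()` contract of `SimpleQueue` while remaining an ordinary picklable object that can travel inside the returned `_SpawnOutput`. A quick standalone demonstration using a stand-in class with the same three methods:

    import pickle
    from collections import UserList
    from typing import Any

    class FakeQueue(UserList):
        """Stand-in mirroring the `_FakeQueue` defined above."""

        def get(self) -> Any:
            return self.pop(0)

        def put(self, item: Any) -> None:
            self.append(item)

        def empty(self) -> bool:
            return len(self) == 0

    q = FakeQueue()
    q.put("first")
    q.put("second")
    assert q.get() == "first"  # FIFO, same contract as SimpleQueue
    assert not q.empty()
    restored = pickle.loads(pickle.dumps(q))  # plain list data, so it pickles cleanly
    assert restored.get() == "second"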
@@ -12,8 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from contextlib import contextmanager
-from multiprocessing.queues import SimpleQueue
-from typing import Dict, Generator, List, Optional, Tuple
+from typing import Any, Dict, Generator, List, Optional, Tuple

 import torch
 from torch.nn import Module
@@ -21,7 +20,7 @@ from torch.optim import Optimizer

 import pytorch_lightning as pl
 from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin
-from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
+from pytorch_lightning.plugins.training_type.ddp_spawn import _FakeQueue, DDPSpawnPlugin
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, rank_zero_only
 from pytorch_lightning.utilities.enums import _StrategyType
@@ -115,12 +114,12 @@ class DDPSpawnShardedPlugin(DDPSpawnPlugin):
     def post_training_step(self):
         pass

-    def new_process(self, trainer: "pl.Trainer", mp_queue: SimpleQueue) -> None:
+    def new_process(self, trainer: "pl.Trainer") -> Optional[Tuple[Optional[str], Optional[str], Any, _FakeQueue]]:
         # Ensure that the scaler points to the correct process group
         # which is re-initialized in a new process
         if isinstance(self.precision_plugin, ShardedNativeMixedPrecisionPlugin):
             self._precision_plugin.scaler = ShardedGradScaler()
-        return super().new_process(trainer, mp_queue)
+        return super().new_process(trainer)

     @classmethod
     def register_plugins(cls, plugin_registry: Dict) -> None:
@@ -16,7 +16,7 @@ import os
 import re
 import time
 from multiprocessing.queues import SimpleQueue
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import torch
 import torch.multiprocessing as mp
@@ -29,7 +29,7 @@ from pytorch_lightning.overrides import LightningDistributedModule
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
 from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO
 from pytorch_lightning.plugins.precision import PrecisionPlugin
-from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
+from pytorch_lightning.plugins.training_type.ddp_spawn import _FakeQueue, _SpawnOutput, DDPSpawnPlugin
 from pytorch_lightning.trainer.connectors.data_connector import DataConnector
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities import _TPU_AVAILABLE, find_shared_parameters, rank_zero_warn, set_shared_parameters
@@ -123,7 +123,7 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
         os.environ["PT_XLA_DEBUG"] = str(1)

     def setup(self, trainer: "pl.Trainer") -> None:
-        self.create_mp_queue()
+        self.start_method = "fork"
         if not self.setup_optimizers_in_pre_dispatch:
             self.setup_optimizers(trainer)
         self.setup_precision_plugin()
@@ -131,11 +131,6 @@
     def _setup_model(self, model: Module) -> Module:
         return model

-    def create_mp_queue(self):
-        self.start_method = "fork"
-        smp = mp.get_context(self.start_method)
-        self.mp_queue = smp.SimpleQueue()
-
     @property
     def distributed_sampler_kwargs(self) -> Dict[str, int]:
         return dict(num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
@@ -161,9 +156,7 @@
     def set_world_ranks(self, process_idx: int = 0) -> None:
         pass

-    def new_process(self, trainer: "pl.Trainer", mp_queue: SimpleQueue) -> None:
-        self.mp_queue = mp_queue
-
+    def new_process(self, trainer: "pl.Trainer") -> Optional[Tuple[Optional[str], Optional[str], Any, "_FakeQueue"]]:
         if self.tpu_global_core_rank != 0 and trainer.progress_bar_callback is not None:
             trainer.progress_bar_callback.disable()

@@ -181,7 +174,7 @@

         results = trainer.run_stage()

-        self.__transfer_distrib_spawn_state_on_fit_end(trainer, results)
+        outputs = self.__collect_rank_zero_results(trainer, results)

         # https://github.com/pytorch/xla/issues/1801#issuecomment-602799542
         self.barrier("end-process")
@@ -192,6 +185,7 @@

         # ensure that spawned processes go through teardown before joining
         trainer._call_teardown_hook()
+        return outputs

     def model_to_device(self) -> None:
         self.model = self.wrapped_model.to(self.root_device)
@@ -200,7 +194,7 @@
         if self.is_distributed:
             rendezvous(name)

-    def __transfer_distrib_spawn_state_on_fit_end(self, trainer: "pl.Trainer", results: Any) -> None:
+    def __collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_SpawnOutput"]:
         rank_zero_warn("cleaning up tpu spawn environment...")
         checkpoint_callback = trainer.checkpoint_callback
         best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None
@@ -215,17 +209,18 @@
             self.checkpoint_io.save_checkpoint(state_dict, last_path)

         # We use `local_rank` here as separate filesystems are used for each VM for TPU Pod Training
-        if self.local_rank == 0:
-            # todo, pass complete checkpoint as state dictionary
-            self.mp_queue.put(best_model_path)
-            self.mp_queue.put(last_path)
-            self.mp_queue.put(results)
-            # adds the `callback_metrics` to the queue
-            # TODO: Remove the if in v1.7
-            if is_overridden("add_to_queue", self.lightning_module):
-                self.lightning_module.add_to_queue(self.mp_queue)
-            else:
-                self.add_to_queue(trainer, self.mp_queue)
+        if self.local_rank != 0:
+            return
+
+        # adds the `callback_metrics` to the queue
+        extra = _FakeQueue()
+        if is_overridden("add_to_queue", self.lightning_module):
+            # TODO: Remove the if in v1.7
+            self.lightning_module.add_to_queue(extra)
+        else:
+            self.add_to_queue(trainer, extra)
+
+        return _SpawnOutput(best_model_path, last_path, results, extra)

     def broadcast(self, obj: object, src: int = 0) -> object:
         if not self.is_distributed:
@@ -269,7 +264,7 @@
             "start_method": self.start_method,
         }

-    def spawn(self, function: Callable, *args: Any, **kwargs: Any) -> Optional[Any]:
+    def spawn(self, function: Callable, *args: Any, **kwargs: Any) -> Optional[Union[Any, "_SpawnOutput"]]:
         context = mp.get_context(self.start_method or "fork")
         return_queue = context.SimpleQueue()
         xmp.spawn(self._wrapped_function, args=(function, args, kwargs, return_queue), **self.get_mp_spawn_kwargs())
@@ -294,18 +289,18 @@
         self.tpu_global_core_rank = xm.get_ordinal()
         rank_zero_only.rank = self.global_rank

-    def start_training(self, trainer: "pl.Trainer") -> None:
+    def start_training(self, trainer: "pl.Trainer") -> Any:
         # todo: precision pluging is call in accelerator setup and should be moved
         if "XLA_USE_BF16" in os.environ:
             del os.environ["XLA_USE_BF16"]
         self._clean_logger(trainer)
         return super().start_training(trainer)

-    def start_evaluating(self, trainer: "pl.Trainer") -> None:
+    def start_evaluating(self, trainer: "pl.Trainer") -> Any:
         self._clean_logger(trainer)
         return super().start_evaluating(trainer)

-    def start_predicting(self, trainer: "pl.Trainer") -> None:
+    def start_predicting(self, trainer: "pl.Trainer") -> Any:
         self._clean_logger(trainer)
         return super().start_predicting(trainer)
@@ -30,7 +30,7 @@ from pytorch_lightning.plugins.precision import PrecisionPlugin
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
 from pytorch_lightning.utilities.distributed import ReduceOp
-from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT, _PATH, _PREDICT_OUTPUT
+from pytorch_lightning.utilities.types import _PATH

 TBroadcast = TypeVar("TBroadcast")

@@ -43,7 +43,6 @@ class TrainingTypePlugin(ABC):
         self, checkpoint_io: Optional[CheckpointIO] = None, precision_plugin: Optional[PrecisionPlugin] = None
     ) -> None:
         self._model: Optional[Module] = None
-        self._results: Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]] = None
         checkpoint_io = checkpoint_io if checkpoint_io is not None else TorchCheckpointIO()
         self._checkpoint_io = checkpoint_io
         self._precision_plugin = precision_plugin if precision_plugin is not None else PrecisionPlugin()
@@ -291,18 +290,6 @@ class TrainingTypePlugin(ABC):
         """Returns the pure LightningModule without potential wrappers."""
         return unwrap_lightning_module(self._model) if self._model is not None else None

-    @property
-    def results(self) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]:
-        """Enables plugin-agnostic access to the result returned by the training/evaluation/prediction run.
-
-        The result is
-        cached instead of returned directly, because some plugins require transmitting the results from one
-        multiprocessing context to another in a separate step. For example, the plugins that use the "spawn"
-        start-method send the result to the main process through a
-        `multiprocessing queue (shared memory) <https://pytorch.org/docs/stable/multiprocessing.html>`_.
-        """
-        return self._results
-
     def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:
         torch.cuda.empty_cache()
         return self.checkpoint_io.load_checkpoint(checkpoint_path)
@@ -315,17 +302,17 @@ class TrainingTypePlugin(ABC):
         for optimizer, opt_state in zip(self.optimizers, optimizer_states):
             optimizer.load_state_dict(opt_state)

-    def start_training(self, trainer: "pl.Trainer") -> None:
+    def start_training(self, trainer: "pl.Trainer") -> Any:
         # double dispatch to initiate the training loop
-        self._results = trainer.run_stage()
+        return trainer.run_stage()

-    def start_evaluating(self, trainer: "pl.Trainer") -> None:
+    def start_evaluating(self, trainer: "pl.Trainer") -> Any:
         # double dispatch to initiate the test loop
-        self._results = trainer.run_stage()
+        return trainer.run_stage()

-    def start_predicting(self, trainer: "pl.Trainer") -> None:
+    def start_predicting(self, trainer: "pl.Trainer") -> Any:
         # double dispatch to initiate the predicting loop
-        self._results = trainer.run_stage()
+        return trainer.run_stage()

     def training_step(self, *args, **kwargs):
         return self.model.training_step(*args, **kwargs)
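The change to the base class is easiest to see side by side: previously the stage output was cached on the plugin and read back later through the (now removed) `results` property; now it flows through the return value. A plain-Python illustration with stand-in classes, not the Lightning ones:

    from typing import Any

    # before: the plugin cached the stage output; a separate step had to read it back
    class CachingPlugin:
        def __init__(self) -> None:
            self._results: Any = None

        def start_training(self, run_stage) -> None:
            self._results = run_stage()

    # after: the output is returned directly, so there is no hidden state to recover
    class ReturningPlugin:
        def start_training(self, run_stage) -> Any:
            return run_stage()

    results = ReturningPlugin().start_training(lambda: {"loss": 0.1})
    assert results == {"loss": 0.1}

The Trainer hunks below consume that return value directly.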
@@ -1159,9 +1159,8 @@ class Trainer(
             self.checkpoint_connector.resume_end()

         # dispatch `start_training` or `start_evaluating` or `start_predicting`
-        self._dispatch()
+        results = self._dispatch()

         # plugin will finalized fitting (e.g. ddp_spawn will load trained model)
         self._post_dispatch()

         # ----------------------------
@@ -1180,7 +1179,7 @@
         self.state.status = TrainerStatus.FINISHED
         self.state.stage = None

-        return self.training_type_plugin.results
+        return results

     def _pre_dispatch(self):
         self.accelerator.pre_dispatch(self)
@@ -1233,13 +1232,13 @@
         self.logger_connector.teardown()
         self.signal_connector.teardown()

-    def _dispatch(self):
+    def _dispatch(self) -> Any:
         if self.evaluating:
-            self.training_type_plugin.start_evaluating(self)
+            return self.training_type_plugin.start_evaluating(self)
         elif self.predicting:
-            self.training_type_plugin.start_predicting(self)
+            return self.training_type_plugin.start_predicting(self)
         else:
-            self.training_type_plugin.start_training(self)
+            return self.training_type_plugin.start_training(self)

     def run_stage(self):
         self.accelerator.dispatch(self)
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import pytest
 import torch
 from torch.nn.parallel.distributed import DistributedDataParallel
@@ -38,11 +39,11 @@ class BoringCallbackDDPSpawnModel(BoringModel):
         self.log(self.name, self.val)
         return super().validation_step(batch, batch_idx)

-    def add_to_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None:
+    def add_to_queue(self, queue) -> None:
         queue.put("test_val")
         return super().add_to_queue(queue)

-    def get_from_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None:
+    def get_from_queue(self, queue) -> None:
         self.test_val = queue.get()
         return super().get_from_queue(queue)
@@ -84,11 +85,11 @@ def test_ddp_spawn_extra_parameters(tmpdir):


 class TestDDPSpawnPlugin(DDPSpawnPlugin):
-    def add_to_queue(self, trainer: Trainer, queue: torch.multiprocessing.SimpleQueue) -> None:
+    def add_to_queue(self, trainer, queue) -> None:
         queue.put("new_test_val")
         return super().add_to_queue(trainer, queue)

-    def get_from_queue(self, trainer: Trainer, queue: torch.multiprocessing.SimpleQueue) -> None:
+    def get_from_queue(self, trainer: Trainer, queue) -> None:
         self.new_test_val = queue.get()
         return super().get_from_queue(trainer, queue)