From 98cb7e87908a2413ef2f31ca2bfb794f1eab3aa4 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Thu, 2 Dec 2021 11:30:44 +0100
Subject: [PATCH] 1/n Simplify spawn plugins: Simplify handling of multiprocessing queue (#10034)

Co-authored-by: thomas chaton
---
 CHANGELOG.md                                  |  12 +-
 pytorch_lightning/core/lightning.py           |   4 +-
 .../plugins/training_type/ddp_spawn.py        | 124 +++++++++---------
 .../plugins/training_type/sharded_spawn.py    |   9 +-
 .../plugins/training_type/tpu_spawn.py        |  49 ++++---
 .../training_type/training_type_plugin.py     |  27 +---
 pytorch_lightning/trainer/trainer.py          |  13 +-
 tests/plugins/test_ddp_spawn_plugin.py        |   9 +-
 8 files changed, 122 insertions(+), 125 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index fa9a4f4f23..e965454486 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -80,7 +80,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Moved `batch_to_device` method from `Accelerator` to `TrainingTypePlugin` ([#10649](https://github.com/PyTorchLightning/pytorch-lightning/pull/10649))
 
--
+- The `DDPSpawnPlugin` no longer overrides the `post_dispatch` plugin hook ([#10034](https://github.com/PyTorchLightning/pytorch-lightning/pull/10034))
+
+
+- The `LightningModule.{add_to_queue,get_from_queue}` hooks no longer get a `torch.multiprocessing.SimpleQueue` and instead receive a list-based queue ([#10034](https://github.com/PyTorchLightning/pytorch-lightning/pull/10034))
+
 
 ### Deprecated
 
@@ -188,6 +192,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Removed argument `return_result` from the `DDPSpawnPlugin.spawn()` method ([#10867](https://github.com/PyTorchLightning/pytorch-lightning/pull/10867))
 
+- Removed the property `TrainingTypePlugin.results` and corresponding properties in subclasses ([#10034](https://github.com/PyTorchLightning/pytorch-lightning/pull/10034))
+
+
+- Removed the `mp_queue` attribute from `DDPSpawnPlugin` and `TPUSpawnPlugin` ([#10034](https://github.com/PyTorchLightning/pytorch-lightning/pull/10034))
+
+
 - Removed unnecessary `_move_optimizer_state` method overrides from `TPUSpawnPlugin` and `SingleTPUPlugin` ([#10849](https://github.com/PyTorchLightning/pytorch-lightning/pull/10849))
 
diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index 45109e5050..e02c9d32ec 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -1917,7 +1917,7 @@ class LightningModule(
         )
         return get_model_size_mb(self)
 
-    def add_to_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None:
+    def add_to_queue(self, queue: pl.plugins.training_type.ddp_spawn._FakeQueue) -> None:
         """Appends the :attr:`trainer.callback_metrics` dictionary to the given queue. To avoid issues with memory
         sharing, we cast the data to numpy.
 
@@ -1931,7 +1931,7 @@ class LightningModule(
         if self.trainer and isinstance(self.trainer.training_type_plugin, pl.plugins.training_type.DDPSpawnPlugin):
             self.trainer.training_type_plugin.add_to_queue(self.trainer, queue)
 
-    def get_from_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None:
+    def get_from_queue(self, queue: pl.plugins.training_type.ddp_spawn._FakeQueue) -> None:
         """Retrieve the :attr:`trainer.callback_metrics` dictionary from the given queue. To preserve consistency,
         we cast back the data to ``torch.Tensor``.
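Note on the hook change above: after this patch, overrides of `LightningModule.add_to_queue`/`get_from_queue` receive the list-backed `_FakeQueue` added in the `ddp_spawn.py` hunk below, not a `torch.multiprocessing.SimpleQueue`. A minimal sketch of a user-side override, mirroring the pattern exercised in `tests/plugins/test_ddp_spawn_plugin.py` at the end of this patch; the attribute name `my_val` is illustrative only:

```python
import pytorch_lightning as pl


class MyModel(pl.LightningModule):
    def add_to_queue(self, queue) -> None:
        # `put` appends to the list-backed queue; this runs in the spawned worker
        queue.put("my_val")
        return super().add_to_queue(queue)

    def get_from_queue(self, queue) -> None:
        # `get` pops from the front (FIFO), so values come back in the main
        # process in the same order the worker put them in
        self.my_val = queue.get()
        return super().get_from_queue(queue)
```

Because `_FakeQueue` is a plain `UserList`, its contents travel back to the main process pickled inside the spawn output rather than through queue shared memory, which is also why the `__getstate__`/`__setstate__` pickling workaround can be deleted below.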
diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py
index 512665e18c..563f39a1f0 100644
--- a/pytorch_lightning/plugins/training_type/ddp_spawn.py
+++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -14,8 +14,9 @@
 import logging
 import os
 import re
+from collections import UserList
 from multiprocessing.queues import SimpleQueue
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, NamedTuple, Optional, Union
 
 import numpy as np
 import torch
@@ -45,7 +46,7 @@ from pytorch_lightning.utilities.distributed import (
 from pytorch_lightning.utilities.enums import _StrategyType
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.seed import reset_seed
-from pytorch_lightning.utilities.types import STEP_OUTPUT
+from pytorch_lightning.utilities.types import _PATH, STEP_OUTPUT
 
 if _TORCH_GREATER_EQUAL_1_8:
     from pytorch_lightning.utilities.distributed import register_ddp_comm_hook
@@ -80,7 +81,6 @@ class DDPSpawnPlugin(ParallelPlugin):
         self.sync_batchnorm = False
         self._ddp_kwargs = kwargs
         self.num_processes = len(parallel_devices) if parallel_devices is not None else 0
-        self.mp_queue = None
         self._ddp_comm_state = ddp_comm_state
         self._ddp_comm_hook = ddp_comm_hook
         self._ddp_comm_wrapper = ddp_comm_wrapper
@@ -101,15 +101,6 @@
     def local_rank(self) -> int:
         return self._local_rank
 
-    def __getstate__(self):
-        """Makes this plugin pickleable without destroying the queue in the current process."""
-        state = self.__dict__.copy()
-        state["mp_queue"] = None
-        return state
-
-    def __setstate__(self, state):
-        self.__dict__ = state
-
     @property
     def root_device(self):
         return self.parallel_devices[self.local_rank]
@@ -125,9 +116,6 @@
 
     def setup(self, trainer: "pl.Trainer") -> None:
         os.environ["MASTER_PORT"] = str(self.cluster_environment.main_port)
-        # pass in a state q
-        smp = mp.get_context("spawn")
-        self.mp_queue = smp.SimpleQueue()
         super().setup(trainer)
 
     def _setup_model(self, model: Module) -> DistributedDataParallel:
@@ -145,18 +133,24 @@
     def get_mp_spawn_kwargs(self, trainer: Optional["pl.Trainer"] = None) -> Dict[str, Any]:
         return {"nprocs": self.num_processes}
 
-    def start_training(self, trainer: "pl.Trainer") -> None:
-        self.spawn(self.new_process, trainer, self.mp_queue)
+    def start_training(self, trainer: "pl.Trainer") -> Any:
+        spawn_output: _SpawnOutput = self.spawn(self.new_process, trainer)
+        self.__recover_results_in_main_process(spawn_output, trainer)
         # reset optimizers, since main process is never used for training and thus does not have a valid optim state
         trainer.optimizers = []
+        return spawn_output.trainer_results
 
-    def start_evaluating(self, trainer: "pl.Trainer") -> None:
-        self.spawn(self.new_process, trainer, self.mp_queue)
+    def start_evaluating(self, trainer: "pl.Trainer") -> Any:
+        spawn_output: _SpawnOutput = self.spawn(self.new_process, trainer)
+        self.__recover_results_in_main_process(spawn_output, trainer)
+        return spawn_output.trainer_results
 
-    def start_predicting(self, trainer: "pl.Trainer") -> None:
-        self.spawn(self.new_process, trainer, self.mp_queue)
+    def start_predicting(self, trainer: "pl.Trainer") -> Any:
+        spawn_output: _SpawnOutput = self.spawn(self.new_process, trainer)
+        self.__recover_results_in_main_process(spawn_output, trainer)
+        return spawn_output.trainer_results
 
-    def spawn(self, function: Callable, *args: Any, **kwargs: Any) -> Optional[Any]:
+    def spawn(self, function: Callable, *args: Any, **kwargs: Any) -> Optional[Union[Any, "_SpawnOutput"]]:
        """Spawn processes that run the given function.
 
         Args:
@@ -191,9 +185,7 @@
             self.cluster_environment, self.torch_distributed_backend, self.global_rank, self.world_size
         )
 
-    def new_process(self, trainer: "pl.Trainer", mp_queue: SimpleQueue) -> None:
-        self.mp_queue = mp_queue
-
+    def new_process(self, trainer: "pl.Trainer") -> Optional["_SpawnOutput"]:
         # move the model to the correct device
         self.model_to_device()
 
@@ -208,28 +200,11 @@
         self.barrier()
 
         results = trainer.run_stage()
-
-        # persist info in ddp_spawn
-        self.__transfer_distrib_spawn_state_on_fit_end(trainer, results)
+        outputs = self.__collect_rank_zero_results(trainer, results)
 
         # ensure that spawned processes go through teardown before joining
         trainer._call_teardown_hook()
-
-    def post_dispatch(self, trainer: "pl.Trainer"):
-        # restore main state with best weights
-        best_path = self.mp_queue.get()
-        last_path = self.mp_queue.get()
-        self._results = self.mp_queue.get()
-        # get the `callback_metrics` and set it to the trainer
-        # only in case the user does not override it.
-        # TODO: Remove the if in v1.7
-        if is_overridden("get_from_queue", self.lightning_module):
-            self.lightning_module.get_from_queue(self.mp_queue)
-        else:
-            self.get_from_queue(trainer, self.mp_queue)
-
-        # recover the weights of the processes trained in the children
-        self.__recover_child_process_weights(best_path, last_path)
+        return outputs
 
     def pre_configure_ddp(self):
         # if unset, default `find_unused_parameters` `True`
@@ -268,7 +243,7 @@
             return None
         return [self.root_device.index]
 
-    def __transfer_distrib_spawn_state_on_fit_end(self, trainer: "pl.Trainer", results: Any) -> None:
+    def __collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_SpawnOutput"]:
         rank_zero_warn("cleaning up ddp environment...")
         checkpoint_callback = trainer.checkpoint_callback
         best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None
@@ -285,28 +260,37 @@
             last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path)
             self.checkpoint_io.save_checkpoint(state_dict, last_path)
 
-        # todo, pass complete checkpoint as state dictionary
-        self.mp_queue.put(best_model_path)
-        self.mp_queue.put(last_path)
-        self.mp_queue.put(results)
-
         # adds the `callback_metrics` to the queue
-        # TODO: Remove the if in v1.7
+        extra = _FakeQueue()
         if is_overridden("add_to_queue", self.lightning_module):
-            self.lightning_module.add_to_queue(self.mp_queue)
+            # TODO: Remove the if in v1.7
+            self.lightning_module.add_to_queue(extra)
         else:
-            self.add_to_queue(trainer, self.mp_queue)
+            self.add_to_queue(trainer, extra)
 
-    def __recover_child_process_weights(self, best_path, last_path):
+        return _SpawnOutput(best_model_path, last_path, results, extra)
+
+    def __recover_results_in_main_process(self, spawn_output: "_SpawnOutput", trainer) -> None:
         # transfer back the best path to the trainer
         if self.lightning_module.trainer.checkpoint_callback:
-            self.lightning_module.trainer.checkpoint_callback.best_model_path = best_path
-        # todo, pass also best score
+            self.lightning_module.trainer.checkpoint_callback.best_model_path = spawn_output.best_model_path
+        # TODO: pass also best score
 
         # load last weights
-        if last_path is not None and self.lightning_module.trainer.state.fn == TrainerFn.FITTING:
-            ckpt = self.checkpoint_io.load_checkpoint(last_path, map_location=(lambda storage, loc: storage))
+        if spawn_output.last_path is not None and self.lightning_module.trainer.state.fn == TrainerFn.FITTING:
+            ckpt = self.checkpoint_io.load_checkpoint(
+                spawn_output.last_path, map_location=(lambda storage, loc: storage)
+            )
             self.lightning_module.load_state_dict(ckpt)
 
+        # get the `callback_metrics` and set it to the trainer
+        if is_overridden("get_from_queue", self.lightning_module):
+            # only in case the user does not override it.
+            # TODO: Remove the if in v1.7
+            self.lightning_module.get_from_queue(spawn_output.extra)
+        else:
+            self.get_from_queue(trainer, spawn_output.extra)
+
     def barrier(self, *args, **kwargs) -> None:
         if not distributed_available():
             return
@@ -372,11 +356,12 @@
         if not self.lightning_module.automatic_optimization:
             self.model.require_backward_grad_sync = True
 
-    def add_to_queue(self, trainer: "pl.Trainer", queue: torch.multiprocessing.SimpleQueue) -> None:
+    def add_to_queue(self, trainer: "pl.Trainer", queue: "_FakeQueue") -> None:
         """Appends the :attr:`trainer.callback_metrics` dictionary to the given queue. To avoid issues with memory
         sharing, we cast the data to numpy.
 
         Args:
+            trainer: reference to the Trainer.
             queue: the instance of the queue to append the data.
         """
         callback_metrics: dict = apply_to_collection(
@@ -384,11 +369,12 @@
         )  # send as numpy to avoid issues with memory sharing
         queue.put(callback_metrics)
 
-    def get_from_queue(self, trainer: "pl.Trainer", queue: torch.multiprocessing.SimpleQueue) -> None:
+    def get_from_queue(self, trainer: "pl.Trainer", queue: "_FakeQueue") -> None:
         """Retrieve the :attr:`trainer.callback_metrics` dictionary from the given queue. To preserve consistency,
         we cast back the data to ``torch.Tensor``.
 
         Args:
+            trainer: reference to the Trainer.
             queue: the instance of the queue from where to get the data.
         """
         # NOTE: `add_to_queue` needs to be called before
@@ -413,3 +399,23 @@
         self.lightning_module.cpu()
         # clean up memory
         torch.cuda.empty_cache()
+
+
+class _FakeQueue(UserList):
+    """Simulates a :class:`torch.multiprocessing.queue.SimpleQueue` interface using a Python list."""
+
+    def get(self) -> Any:
+        return self.pop(0)
+
+    def put(self, item: Any) -> None:
+        self.append(item)
+
+    def empty(self) -> bool:
+        return len(self) == 0
+
+
+class _SpawnOutput(NamedTuple):
+    best_model_path: Optional[_PATH]
+    last_path: Optional[_PATH]
+    trainer_results: Any
+    extra: _FakeQueue
diff --git a/pytorch_lightning/plugins/training_type/sharded_spawn.py b/pytorch_lightning/plugins/training_type/sharded_spawn.py
index 70bc3bc16c..5e10155cc3 100644
--- a/pytorch_lightning/plugins/training_type/sharded_spawn.py
+++ b/pytorch_lightning/plugins/training_type/sharded_spawn.py
@@ -12,8 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from contextlib import contextmanager
-from multiprocessing.queues import SimpleQueue
-from typing import Dict, Generator, List, Optional, Tuple
+from typing import Any, Dict, Generator, List, Optional, Tuple
 
 import torch
 from torch.nn import Module
@@ -21,7 +20,7 @@ from torch.optim import Optimizer
 
 import pytorch_lightning as pl
 from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin
-from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
+from pytorch_lightning.plugins.training_type.ddp_spawn import _FakeQueue, DDPSpawnPlugin
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, rank_zero_only
 from pytorch_lightning.utilities.enums import _StrategyType
@@ -115,12 +114,12 @@ class DDPSpawnShardedPlugin(DDPSpawnPlugin):
     def post_training_step(self):
         pass
 
-    def new_process(self, trainer: "pl.Trainer", mp_queue: SimpleQueue) -> None:
+    def new_process(self, trainer: "pl.Trainer") -> Optional[Tuple[Optional[str], Optional[str], Any, _FakeQueue]]:
         # Ensure that the scaler points to the correct process group
         # which is re-initialized in a new process
         if isinstance(self.precision_plugin, ShardedNativeMixedPrecisionPlugin):
             self._precision_plugin.scaler = ShardedGradScaler()
-        return super().new_process(trainer, mp_queue)
+        return super().new_process(trainer)
 
     @classmethod
     def register_plugins(cls, plugin_registry: Dict) -> None:
diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py
index 79f4797eda..02a8998592 100644
--- a/pytorch_lightning/plugins/training_type/tpu_spawn.py
+++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -16,7 +16,7 @@ import os
 import re
 import time
 from multiprocessing.queues import SimpleQueue
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.multiprocessing as mp
@@ -29,7 +29,7 @@ from pytorch_lightning.overrides import LightningDistributedModule
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
 from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO
 from pytorch_lightning.plugins.precision import PrecisionPlugin
-from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
+from pytorch_lightning.plugins.training_type.ddp_spawn import _FakeQueue, _SpawnOutput, DDPSpawnPlugin
 from pytorch_lightning.trainer.connectors.data_connector import DataConnector
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities import _TPU_AVAILABLE, find_shared_parameters, rank_zero_warn, set_shared_parameters
@@ -123,7 +123,7 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
             os.environ["PT_XLA_DEBUG"] = str(1)
 
     def setup(self, trainer: "pl.Trainer") -> None:
-        self.create_mp_queue()
+        self.start_method = "fork"
         if not self.setup_optimizers_in_pre_dispatch:
             self.setup_optimizers(trainer)
         self.setup_precision_plugin()
@@ -131,11 +131,6 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
     def _setup_model(self, model: Module) -> Module:
         return model
 
-    def create_mp_queue(self):
-        self.start_method = "fork"
-        smp = mp.get_context(self.start_method)
-        self.mp_queue = smp.SimpleQueue()
-
     @property
     def distributed_sampler_kwargs(self) -> Dict[str, int]:
         return dict(num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
@@ -161,9 +156,7 @@
     def set_world_ranks(self, process_idx: int = 0) -> None:
         pass
 
-    def new_process(self, trainer: "pl.Trainer", mp_queue: SimpleQueue) -> None:
-        self.mp_queue = mp_queue
-
+    def new_process(self, trainer: "pl.Trainer") -> Optional[Tuple[Optional[str], Optional[str], Any, "_FakeQueue"]]:
         if self.tpu_global_core_rank != 0 and trainer.progress_bar_callback is not None:
             trainer.progress_bar_callback.disable()
 
@@ -181,7 +174,7 @@
 
         results = trainer.run_stage()
 
-        self.__transfer_distrib_spawn_state_on_fit_end(trainer, results)
+        outputs = self.__collect_rank_zero_results(trainer, results)
 
         # https://github.com/pytorch/xla/issues/1801#issuecomment-602799542
         self.barrier("end-process")
@@ -192,6 +185,7 @@
 
         # ensure that spawned processes go through teardown before joining
         trainer._call_teardown_hook()
+        return outputs
 
     def model_to_device(self) -> None:
         self.model = self.wrapped_model.to(self.root_device)
@@ -200,7 +194,7 @@
         if self.is_distributed:
             rendezvous(name)
 
-    def __transfer_distrib_spawn_state_on_fit_end(self, trainer: "pl.Trainer", results: Any) -> None:
+    def __collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_SpawnOutput"]:
         rank_zero_warn("cleaning up tpu spawn environment...")
         checkpoint_callback = trainer.checkpoint_callback
         best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None
@@ -215,17 +209,18 @@
             self.checkpoint_io.save_checkpoint(state_dict, last_path)
 
         # We use `local_rank` here as separate filesystems are used for each VM for TPU Pod Training
-        if self.local_rank == 0:
-            # todo, pass complete checkpoint as state dictionary
-            self.mp_queue.put(best_model_path)
-            self.mp_queue.put(last_path)
-            self.mp_queue.put(results)
-            # adds the `callback_metrics` to the queue
+        if self.local_rank != 0:
+            return
+
+        # adds the `callback_metrics` to the queue
+        extra = _FakeQueue()
+        if is_overridden("add_to_queue", self.lightning_module):
             # TODO: Remove the if in v1.7
-            if is_overridden("add_to_queue", self.lightning_module):
-                self.lightning_module.add_to_queue(self.mp_queue)
-            else:
-                self.add_to_queue(trainer, self.mp_queue)
+            self.lightning_module.add_to_queue(extra)
+        else:
+            self.add_to_queue(trainer, extra)
+
+        return _SpawnOutput(best_model_path, last_path, results, extra)
 
     def broadcast(self, obj: object, src: int = 0) -> object:
         if not self.is_distributed:
@@ -269,7 +264,7 @@
             "start_method": self.start_method,
         }
 
-    def spawn(self, function: Callable, *args: Any, **kwargs: Any) -> Optional[Any]:
+    def spawn(self, function: Callable, *args: Any, **kwargs: Any) -> Optional[Union[Any, "_SpawnOutput"]]:
         context = mp.get_context(self.start_method or "fork")
         return_queue = context.SimpleQueue()
         xmp.spawn(self._wrapped_function, args=(function, args, kwargs, return_queue), **self.get_mp_spawn_kwargs())
@@ -294,18 +289,18 @@
         self.tpu_global_core_rank = xm.get_ordinal()
         rank_zero_only.rank = self.global_rank
 
-    def start_training(self, trainer: "pl.Trainer") -> None:
+    def start_training(self, trainer: "pl.Trainer") -> Any:
         # todo: precision plugin is called in accelerator setup and should be moved
         if "XLA_USE_BF16" in os.environ:
             del os.environ["XLA_USE_BF16"]
         self._clean_logger(trainer)
         return super().start_training(trainer)
 
-    def start_evaluating(self, trainer: "pl.Trainer") -> None:
+    def start_evaluating(self, trainer: "pl.Trainer") -> Any:
         self._clean_logger(trainer)
         return super().start_evaluating(trainer)
 
-    def start_predicting(self, trainer: "pl.Trainer") -> None:
+    def start_predicting(self, trainer: "pl.Trainer") -> Any:
         self._clean_logger(trainer)
         return super().start_predicting(trainer)
diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py
index 9334df4b18..da70b71d5e 100644
--- a/pytorch_lightning/plugins/training_type/training_type_plugin.py
+++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -30,7 +30,7 @@ from pytorch_lightning.plugins.precision import PrecisionPlugin
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
 from pytorch_lightning.utilities.distributed import ReduceOp
-from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT, _PATH, _PREDICT_OUTPUT
+from pytorch_lightning.utilities.types import _PATH
 
 TBroadcast = TypeVar("TBroadcast")
 
@@ -43,7 +43,6 @@ class TrainingTypePlugin(ABC):
         self, checkpoint_io: Optional[CheckpointIO] = None, precision_plugin: Optional[PrecisionPlugin] = None
     ) -> None:
         self._model: Optional[Module] = None
-        self._results: Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]] = None
         checkpoint_io = checkpoint_io if checkpoint_io is not None else TorchCheckpointIO()
         self._checkpoint_io = checkpoint_io
         self._precision_plugin = precision_plugin if precision_plugin is not None else PrecisionPlugin()
@@ -291,18 +290,6 @@ class TrainingTypePlugin(ABC):
         """Returns the pure LightningModule without potential wrappers."""
         return unwrap_lightning_module(self._model) if self._model is not None else None
 
-    @property
-    def results(self) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]:
-        """Enables plugin-agnostic access to the result returned by the training/evaluation/prediction run.
-
-        The result is
-        cached instead of returned directly, because some plugins require transmitting the results from one
-        multiprocessing context to another in a separate step. For example, the plugins that use the "spawn"
-        start-method send the result to the main process through a
-        `multiprocessing queue (shared memory) `_.
- """ - return self._results - def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]: torch.cuda.empty_cache() return self.checkpoint_io.load_checkpoint(checkpoint_path) @@ -315,17 +302,17 @@ class TrainingTypePlugin(ABC): for optimizer, opt_state in zip(self.optimizers, optimizer_states): optimizer.load_state_dict(opt_state) - def start_training(self, trainer: "pl.Trainer") -> None: + def start_training(self, trainer: "pl.Trainer") -> Any: # double dispatch to initiate the training loop - self._results = trainer.run_stage() + return trainer.run_stage() - def start_evaluating(self, trainer: "pl.Trainer") -> None: + def start_evaluating(self, trainer: "pl.Trainer") -> Any: # double dispatch to initiate the test loop - self._results = trainer.run_stage() + return trainer.run_stage() - def start_predicting(self, trainer: "pl.Trainer") -> None: + def start_predicting(self, trainer: "pl.Trainer") -> Any: # double dispatch to initiate the predicting loop - self._results = trainer.run_stage() + return trainer.run_stage() def training_step(self, *args, **kwargs): return self.model.training_step(*args, **kwargs) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index b88d8e4ff5..2dd0ddfcc4 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1159,9 +1159,8 @@ class Trainer( self.checkpoint_connector.resume_end() # dispatch `start_training` or `start_evaluating` or `start_predicting` - self._dispatch() + results = self._dispatch() - # plugin will finalized fitting (e.g. ddp_spawn will load trained model) self._post_dispatch() # ---------------------------- @@ -1180,7 +1179,7 @@ class Trainer( self.state.status = TrainerStatus.FINISHED self.state.stage = None - return self.training_type_plugin.results + return results def _pre_dispatch(self): self.accelerator.pre_dispatch(self) @@ -1233,13 +1232,13 @@ class Trainer( self.logger_connector.teardown() self.signal_connector.teardown() - def _dispatch(self): + def _dispatch(self) -> Any: if self.evaluating: - self.training_type_plugin.start_evaluating(self) + return self.training_type_plugin.start_evaluating(self) elif self.predicting: - self.training_type_plugin.start_predicting(self) + return self.training_type_plugin.start_predicting(self) else: - self.training_type_plugin.start_training(self) + return self.training_type_plugin.start_training(self) def run_stage(self): self.accelerator.dispatch(self) diff --git a/tests/plugins/test_ddp_spawn_plugin.py b/tests/plugins/test_ddp_spawn_plugin.py index db61711ab5..6ea265a4bb 100644 --- a/tests/plugins/test_ddp_spawn_plugin.py +++ b/tests/plugins/test_ddp_spawn_plugin.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+
 import pytest
 import torch
 from torch.nn.parallel.distributed import DistributedDataParallel
@@ -38,11 +39,11 @@ class BoringCallbackDDPSpawnModel(BoringModel):
         self.log(self.name, self.val)
         return super().validation_step(batch, batch_idx)
 
-    def add_to_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None:
+    def add_to_queue(self, queue) -> None:
         queue.put("test_val")
         return super().add_to_queue(queue)
 
-    def get_from_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None:
+    def get_from_queue(self, queue) -> None:
         self.test_val = queue.get()
         return super().get_from_queue(queue)
 
@@ -84,11 +85,11 @@ def test_ddp_spawn_extra_parameters(tmpdir):
 
 
 class TestDDPSpawnPlugin(DDPSpawnPlugin):
-    def add_to_queue(self, trainer: Trainer, queue: torch.multiprocessing.SimpleQueue) -> None:
+    def add_to_queue(self, trainer, queue) -> None:
         queue.put("new_test_val")
         return super().add_to_queue(trainer, queue)
 
-    def get_from_queue(self, trainer: Trainer, queue: torch.multiprocessing.SimpleQueue) -> None:
+    def get_from_queue(self, trainer: Trainer, queue) -> None:
         self.new_test_val = queue.get()
         return super().get_from_queue(trainer, queue)
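Taken together, the patch replaces the old protocol of pushing `best_model_path`/`last_path`/`results` piecemeal through a shared `mp_queue` with a single `_SpawnOutput` named tuple that travels through the spawn return queue and is unpacked by `__recover_results_in_main_process`. A standalone sketch of that pattern using plain `torch.multiprocessing`; the names `_Output`, `_worker`, and `run` are hypothetical and only illustrate the mechanism:

```python
from typing import Any, NamedTuple, Optional

import torch.multiprocessing as mp


class _Output(NamedTuple):  # plays the role of _SpawnOutput
    best_model_path: Optional[str]
    results: Any


def _worker(rank: int, return_queue) -> None:
    # every process runs the function; only rank zero reports back,
    # mirroring __collect_rank_zero_results in the patch
    out = _Output(best_model_path="best.ckpt", results={"rank": rank})
    if rank == 0:
        return_queue.put(out)


def run(nprocs: int = 2) -> _Output:
    context = mp.get_context("spawn")
    return_queue = context.SimpleQueue()
    mp.spawn(_worker, args=(return_queue,), nprocs=nprocs)
    # the structured output is recovered in the main process in one step,
    # instead of being stitched together from several order-sensitive reads
    return return_queue.get()


if __name__ == "__main__":
    print(run())
```

Returning one structured object makes the main-process recovery a single `get()` and lets the `start_*` methods simply return `spawn_output.trainer_results`, which is what allows the `TrainingTypePlugin.results` property and the `post_dispatch` override to be removed.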