1/n Simplify spawn plugins: Simplify handling of multiprocessing queue (#10034)

Co-authored-by: thomas chaton <thomas@grid.ai>
Adrian Wälchli 2021-12-02 11:30:44 +01:00 committed by GitHub
parent 541b983b90
commit 98cb7e8790
8 changed files with 122 additions and 125 deletions

View File

@@ -80,7 +80,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Moved `batch_to_device` method from `Accelerator` to `TrainingTypePlugin` ([#10649](https://github.com/PyTorchLightning/pytorch-lightning/pull/10649))
-
- The `DDPSpawnPlugin` no longer overrides the `post_dispatch` plugin hook ([#10034](https://github.com/PyTorchLightning/pytorch-lightning/pull/10034))
- The `LightningModule.{add_to_queue,get_from_queue}` hooks no longer get a `torch.multiprocessing.SimpleQueue` and instead receive a list based queue ([#10034](https://github.com/PyTorchLightning/pytorch-lightning/pull/10034))
### Deprecated
@@ -188,6 +192,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Removed argument `return_result` from the `DDPSpawnPlugin.spawn()` method ([#10867](https://github.com/PyTorchLightning/pytorch-lightning/pull/10867))
- Removed the property `TrainingTypePlugin.results` and corresponding properties in subclasses ([#10034](https://github.com/PyTorchLightning/pytorch-lightning/pull/10034))
- Removed the `mp_queue` attribute from `DDPSpawnPlugin` and `TPUSpawnPlugin` ([#10034](https://github.com/PyTorchLightning/pytorch-lightning/pull/10034))
- Removed unnecessary `_move_optimizer_state` method overrides from `TPUSpawnPlugin` and `SingleTPUPlugin` ([#10849](https://github.com/PyTorchLightning/pytorch-lightning/pull/10849))

View File

@@ -1917,7 +1917,7 @@ class LightningModule(
)
return get_model_size_mb(self)
def add_to_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None:
def add_to_queue(self, queue: pl.plugins.training_type.ddp_spawn._FakeQueue) -> None:
"""Appends the :attr:`trainer.callback_metrics` dictionary to the given queue. To avoid issues with memory
sharing, we cast the data to numpy.
@@ -1931,7 +1931,7 @@ class LightningModule(
if self.trainer and isinstance(self.trainer.training_type_plugin, pl.plugins.training_type.DDPSpawnPlugin):
self.trainer.training_type_plugin.add_to_queue(self.trainer, queue)
def get_from_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None:
def get_from_queue(self, queue: pl.plugins.training_type.ddp_spawn._FakeQueue) -> None:
"""Retrieve the :attr:`trainer.callback_metrics` dictionary from the given queue. To preserve consistency,
we cast back the data to ``torch.Tensor``.
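
The hook signatures above keep the same `put`/`get` semantics; only the queue type changes from `torch.multiprocessing.SimpleQueue` to the list-based `_FakeQueue`. A minimal sketch of a user override under the new signature, modelled on the test at the end of this commit (`MyModel` and the `"extra_val"` value are placeholders, not part of the change):

import pytorch_lightning as pl

class MyModel(pl.LightningModule):  # hypothetical example model
    def add_to_queue(self, queue) -> None:
        # runs in the spawned worker, before results are shipped to the main process
        queue.put("extra_val")
        super().add_to_queue(queue)  # Lightning then appends trainer.callback_metrics

    def get_from_queue(self, queue) -> None:
        # runs in the main process; values come back in the same FIFO order
        self.extra_val = queue.get()
        super().get_from_queue(queue)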

View File

@@ -14,8 +14,9 @@
import logging
import os
import re
from collections import UserList
from multiprocessing.queues import SimpleQueue
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Union
import numpy as np
import torch
@@ -45,7 +46,7 @@ from pytorch_lightning.utilities.distributed import (
from pytorch_lightning.utilities.enums import _StrategyType
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.seed import reset_seed
from pytorch_lightning.utilities.types import STEP_OUTPUT
from pytorch_lightning.utilities.types import _PATH, STEP_OUTPUT
if _TORCH_GREATER_EQUAL_1_8:
from pytorch_lightning.utilities.distributed import register_ddp_comm_hook
@@ -80,7 +81,6 @@ class DDPSpawnPlugin(ParallelPlugin):
self.sync_batchnorm = False
self._ddp_kwargs = kwargs
self.num_processes = len(parallel_devices) if parallel_devices is not None else 0
self.mp_queue = None
self._ddp_comm_state = ddp_comm_state
self._ddp_comm_hook = ddp_comm_hook
self._ddp_comm_wrapper = ddp_comm_wrapper
@@ -101,15 +101,6 @@ class DDPSpawnPlugin(ParallelPlugin):
def local_rank(self) -> int:
return self._local_rank
def __getstate__(self):
"""Makes this plugin pickleable without destroying the queue in the current process."""
state = self.__dict__.copy()
state["mp_queue"] = None
return state
def __setstate__(self, state):
self.__dict__ = state
@property
def root_device(self):
return self.parallel_devices[self.local_rank]
@@ -125,9 +116,6 @@
def setup(self, trainer: "pl.Trainer") -> None:
os.environ["MASTER_PORT"] = str(self.cluster_environment.main_port)
# pass in a state q
smp = mp.get_context("spawn")
self.mp_queue = smp.SimpleQueue()
super().setup(trainer)
def _setup_model(self, model: Module) -> DistributedDataParallel:
@@ -145,18 +133,24 @@
def get_mp_spawn_kwargs(self, trainer: Optional["pl.Trainer"] = None) -> Dict[str, Any]:
return {"nprocs": self.num_processes}
def start_training(self, trainer: "pl.Trainer") -> None:
self.spawn(self.new_process, trainer, self.mp_queue)
def start_training(self, trainer: "pl.Trainer") -> Any:
spawn_output: _SpawnOutput = self.spawn(self.new_process, trainer)
self.__recover_results_in_main_process(spawn_output, trainer)
# reset optimizers, since main process is never used for training and thus does not have a valid optim state
trainer.optimizers = []
return spawn_output.trainer_results
def start_evaluating(self, trainer: "pl.Trainer") -> None:
self.spawn(self.new_process, trainer, self.mp_queue)
def start_evaluating(self, trainer: "pl.Trainer") -> Any:
spawn_output: _SpawnOutput = self.spawn(self.new_process, trainer)
self.__recover_results_in_main_process(spawn_output, trainer)
return spawn_output.trainer_results
def start_predicting(self, trainer: "pl.Trainer") -> None:
self.spawn(self.new_process, trainer, self.mp_queue)
def start_predicting(self, trainer: "pl.Trainer") -> Any:
spawn_output: _SpawnOutput = self.spawn(self.new_process, trainer)
self.__recover_results_in_main_process(spawn_output, trainer)
return spawn_output.trainer_results
def spawn(self, function: Callable, *args: Any, **kwargs: Any) -> Optional[Any]:
def spawn(self, function: Callable, *args: Any, **kwargs: Any) -> Optional[Union[Any, "_SpawnOutput"]]:
"""Spawn processes that run the given function.
Args:
@@ -191,9 +185,7 @@ class DDPSpawnPlugin(ParallelPlugin):
self.cluster_environment, self.torch_distributed_backend, self.global_rank, self.world_size
)
def new_process(self, trainer: "pl.Trainer", mp_queue: SimpleQueue) -> None:
self.mp_queue = mp_queue
def new_process(self, trainer: "pl.Trainer") -> Optional["_SpawnOutput"]:
# move the model to the correct device
self.model_to_device()
@@ -208,28 +200,11 @@
self.barrier()
results = trainer.run_stage()
# persist info in ddp_spawn
self.__transfer_distrib_spawn_state_on_fit_end(trainer, results)
outputs = self.__collect_rank_zero_results(trainer, results)
# ensure that spawned processes go through teardown before joining
trainer._call_teardown_hook()
def post_dispatch(self, trainer: "pl.Trainer"):
# restore main state with best weights
best_path = self.mp_queue.get()
last_path = self.mp_queue.get()
self._results = self.mp_queue.get()
# get the `callback_metrics` and set it to the trainer
# only in case the user does not override it.
# TODO: Remove the if in v1.7
if is_overridden("get_from_queue", self.lightning_module):
self.lightning_module.get_from_queue(self.mp_queue)
else:
self.get_from_queue(trainer, self.mp_queue)
# recover the weights of the processes trained in the children
self.__recover_child_process_weights(best_path, last_path)
return outputs
def pre_configure_ddp(self):
# if unset, default `find_unused_parameters` `True`
@@ -268,7 +243,7 @@ class DDPSpawnPlugin(ParallelPlugin):
return None
return [self.root_device.index]
def __transfer_distrib_spawn_state_on_fit_end(self, trainer: "pl.Trainer", results: Any) -> None:
def __collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_SpawnOutput"]:
rank_zero_warn("cleaning up ddp environment...")
checkpoint_callback = trainer.checkpoint_callback
best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None
@@ -285,28 +260,37 @@
last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path)
self.checkpoint_io.save_checkpoint(state_dict, last_path)
# todo, pass complete checkpoint as state dictionary
self.mp_queue.put(best_model_path)
self.mp_queue.put(last_path)
self.mp_queue.put(results)
# adds the `callback_metrics` to the queue
# TODO: Remove the if in v1.7
extra = _FakeQueue()
if is_overridden("add_to_queue", self.lightning_module):
self.lightning_module.add_to_queue(self.mp_queue)
# TODO: Remove the if in v1.7
self.lightning_module.add_to_queue(extra)
else:
self.add_to_queue(trainer, self.mp_queue)
self.add_to_queue(trainer, extra)
def __recover_child_process_weights(self, best_path, last_path):
return _SpawnOutput(best_model_path, last_path, results, extra)
def __recover_results_in_main_process(self, spawn_output: "_SpawnOutput", trainer) -> None:
# transfer back the best path to the trainer
if self.lightning_module.trainer.checkpoint_callback:
self.lightning_module.trainer.checkpoint_callback.best_model_path = best_path
# todo, pass also best score
self.lightning_module.trainer.checkpoint_callback.best_model_path = spawn_output.best_model_path
# TODO: pass also best score
# load last weights
if last_path is not None and self.lightning_module.trainer.state.fn == TrainerFn.FITTING:
ckpt = self.checkpoint_io.load_checkpoint(last_path, map_location=(lambda storage, loc: storage))
if spawn_output.last_path is not None and self.lightning_module.trainer.state.fn == TrainerFn.FITTING:
ckpt = self.checkpoint_io.load_checkpoint(
spawn_output.last_path, map_location=(lambda storage, loc: storage)
)
self.lightning_module.load_state_dict(ckpt)
# get the `callback_metrics` and set it to the trainer
if is_overridden("get_from_queue", self.lightning_module):
# only in case the user does not override it.
# TODO: Remove the if in v1.7
self.lightning_module.get_from_queue(spawn_output.extra)
else:
self.get_from_queue(trainer, spawn_output.extra)
def barrier(self, *args, **kwargs) -> None:
if not distributed_available():
return
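
For context on where the `_SpawnOutput` consumed by `__recover_results_in_main_process` comes from: `spawn()` (whose body is unchanged by this hunk) hands the value returned by `new_process` back to the main process through a per-call return queue (the `TPUSpawnPlugin.spawn` override further below shows such a `return_queue` explicitly). A standalone sketch of that pattern using only the standard library; all names in it (`SpawnOutput`, `worker`) are invented for illustration and none of this is the actual Lightning implementation:

import multiprocessing as mp
from typing import Any, NamedTuple, Optional

class SpawnOutput(NamedTuple):  # same shape as `_SpawnOutput` above, invented name
    best_model_path: Optional[str]
    last_path: Optional[str]
    trainer_results: Any
    extra: list  # stand-in for `_FakeQueue`

def worker(rank: int, return_queue) -> None:
    results = {"loss": 0.123}  # stands in for `trainer.run_stage()`
    if rank == 0:  # only rank zero reports back to the parent
        return_queue.put(SpawnOutput("best.ckpt", "last.ckpt", results, ["metrics"]))

if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    return_queue = ctx.SimpleQueue()
    workers = [ctx.Process(target=worker, args=(rank, return_queue)) for rank in range(2)]
    for proc in workers:
        proc.start()
    for proc in workers:
        proc.join()
    spawn_output = return_queue.get()  # recovered in the main process once the workers exit
    print(spawn_output.trainer_results)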
@@ -372,11 +356,12 @@ class DDPSpawnPlugin(ParallelPlugin):
if not self.lightning_module.automatic_optimization:
self.model.require_backward_grad_sync = True
def add_to_queue(self, trainer: "pl.Trainer", queue: torch.multiprocessing.SimpleQueue) -> None:
def add_to_queue(self, trainer: "pl.Trainer", queue: "_FakeQueue") -> None:
"""Appends the :attr:`trainer.callback_metrics` dictionary to the given queue. To avoid issues with memory
sharing, we cast the data to numpy.
Args:
trainer: reference to the Trainer.
queue: the instance of the queue to append the data.
"""
callback_metrics: dict = apply_to_collection(
@@ -384,11 +369,12 @@
) # send as numpy to avoid issues with memory sharing
queue.put(callback_metrics)
def get_from_queue(self, trainer: "pl.Trainer", queue: torch.multiprocessing.SimpleQueue) -> None:
def get_from_queue(self, trainer: "pl.Trainer", queue: "_FakeQueue") -> None:
"""Retrieve the :attr:`trainer.callback_metrics` dictionary from the given queue. To preserve consistency,
we cast back the data to ``torch.Tensor``.
Args:
trainer: reference to the Trainer.
queue: the instance of the queue from where to get the data.
"""
# NOTE: `add_to_queue` needs to be called before
@@ -413,3 +399,23 @@
self.lightning_module.cpu()
# clean up memory
torch.cuda.empty_cache()
class _FakeQueue(UserList):
"""Simulates a :class:`torch.multiprocessing.queue.SimpleQueue` interface using the Python list."""
def get(self) -> Any:
return self.pop(0)
def put(self, item: Any) -> None:
self.append(item)
def empty(self) -> bool:
return len(self) == 0
class _SpawnOutput(NamedTuple):
best_model_path: Optional[_PATH]
last_path: Optional[_PATH]
trainer_results: Any
extra: _FakeQueue
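
Because `_FakeQueue` is just a `UserList`, it is picklable and can travel back to the main process inside `_SpawnOutput`, which a real `torch.multiprocessing.SimpleQueue` cannot do (hence the removed `__getstate__` workaround above). A short illustrative snippet of its FIFO semantics; the values are made up:

from pytorch_lightning.plugins.training_type.ddp_spawn import _FakeQueue

q = _FakeQueue()
q.put({"train_loss": 0.25})  # append to the underlying list
q.put("extra_value")
assert not q.empty()
assert q.get() == {"train_loss": 0.25}  # pop(0): items come back in FIFO order
assert q.get() == "extra_value"
assert q.empty()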

View File

@@ -12,8 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
from multiprocessing.queues import SimpleQueue
from typing import Dict, Generator, List, Optional, Tuple
from typing import Any, Dict, Generator, List, Optional, Tuple
import torch
from torch.nn import Module
@@ -21,7 +20,7 @@ from torch.optim import Optimizer
import pytorch_lightning as pl
from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin
from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
from pytorch_lightning.plugins.training_type.ddp_spawn import _FakeQueue, DDPSpawnPlugin
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, rank_zero_only
from pytorch_lightning.utilities.enums import _StrategyType
@@ -115,12 +114,12 @@ class DDPSpawnShardedPlugin(DDPSpawnPlugin):
def post_training_step(self):
pass
def new_process(self, trainer: "pl.Trainer", mp_queue: SimpleQueue) -> None:
def new_process(self, trainer: "pl.Trainer") -> Optional[Tuple[Optional[str], Optional[str], Any, _FakeQueue]]:
# Ensure that the scaler points to the correct process group
# which is re-initialized in a new process
if isinstance(self.precision_plugin, ShardedNativeMixedPrecisionPlugin):
self._precision_plugin.scaler = ShardedGradScaler()
return super().new_process(trainer, mp_queue)
return super().new_process(trainer)
@classmethod
def register_plugins(cls, plugin_registry: Dict) -> None:

View File

@@ -16,7 +16,7 @@ import os
import re
import time
from multiprocessing.queues import SimpleQueue
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
import torch.multiprocessing as mp
@@ -29,7 +29,7 @@ from pytorch_lightning.overrides import LightningDistributedModule
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
from pytorch_lightning.plugins.training_type.ddp_spawn import _FakeQueue, _SpawnOutput, DDPSpawnPlugin
from pytorch_lightning.trainer.connectors.data_connector import DataConnector
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import _TPU_AVAILABLE, find_shared_parameters, rank_zero_warn, set_shared_parameters
@@ -123,7 +123,7 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
os.environ["PT_XLA_DEBUG"] = str(1)
def setup(self, trainer: "pl.Trainer") -> None:
self.create_mp_queue()
self.start_method = "fork"
if not self.setup_optimizers_in_pre_dispatch:
self.setup_optimizers(trainer)
self.setup_precision_plugin()
@@ -131,11 +131,6 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
def _setup_model(self, model: Module) -> Module:
return model
def create_mp_queue(self):
self.start_method = "fork"
smp = mp.get_context(self.start_method)
self.mp_queue = smp.SimpleQueue()
@property
def distributed_sampler_kwargs(self) -> Dict[str, int]:
return dict(num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
@@ -161,9 +156,7 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
def set_world_ranks(self, process_idx: int = 0) -> None:
pass
def new_process(self, trainer: "pl.Trainer", mp_queue: SimpleQueue) -> None:
self.mp_queue = mp_queue
def new_process(self, trainer: "pl.Trainer") -> Optional[Tuple[Optional[str], Optional[str], Any, "_FakeQueue"]]:
if self.tpu_global_core_rank != 0 and trainer.progress_bar_callback is not None:
trainer.progress_bar_callback.disable()
@@ -181,7 +174,7 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
results = trainer.run_stage()
self.__transfer_distrib_spawn_state_on_fit_end(trainer, results)
outputs = self.__collect_rank_zero_results(trainer, results)
# https://github.com/pytorch/xla/issues/1801#issuecomment-602799542
self.barrier("end-process")
@@ -192,6 +185,7 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
# ensure that spawned processes go through teardown before joining
trainer._call_teardown_hook()
return outputs
def model_to_device(self) -> None:
self.model = self.wrapped_model.to(self.root_device)
@@ -200,7 +194,7 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
if self.is_distributed:
rendezvous(name)
def __transfer_distrib_spawn_state_on_fit_end(self, trainer: "pl.Trainer", results: Any) -> None:
def __collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_SpawnOutput"]:
rank_zero_warn("cleaning up tpu spawn environment...")
checkpoint_callback = trainer.checkpoint_callback
best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None
@@ -215,17 +209,18 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
self.checkpoint_io.save_checkpoint(state_dict, last_path)
# We use `local_rank` here as separate filesystems are used for each VM for TPU Pod Training
if self.local_rank == 0:
# todo, pass complete checkpoint as state dictionary
self.mp_queue.put(best_model_path)
self.mp_queue.put(last_path)
self.mp_queue.put(results)
# adds the `callback_metrics` to the queue
if self.local_rank != 0:
return
# adds the `callback_metrics` to the queue
extra = _FakeQueue()
if is_overridden("add_to_queue", self.lightning_module):
# TODO: Remove the if in v1.7
if is_overridden("add_to_queue", self.lightning_module):
self.lightning_module.add_to_queue(self.mp_queue)
else:
self.add_to_queue(trainer, self.mp_queue)
self.lightning_module.add_to_queue(extra)
else:
self.add_to_queue(trainer, extra)
return _SpawnOutput(best_model_path, last_path, results, extra)
def broadcast(self, obj: object, src: int = 0) -> object:
if not self.is_distributed:
@@ -269,7 +264,7 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
"start_method": self.start_method,
}
def spawn(self, function: Callable, *args: Any, **kwargs: Any) -> Optional[Any]:
def spawn(self, function: Callable, *args: Any, **kwargs: Any) -> Optional[Union[Any, "_SpawnOutput"]]:
context = mp.get_context(self.start_method or "fork")
return_queue = context.SimpleQueue()
xmp.spawn(self._wrapped_function, args=(function, args, kwargs, return_queue), **self.get_mp_spawn_kwargs())
@@ -294,18 +289,18 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
self.tpu_global_core_rank = xm.get_ordinal()
rank_zero_only.rank = self.global_rank
def start_training(self, trainer: "pl.Trainer") -> None:
def start_training(self, trainer: "pl.Trainer") -> Any:
# todo: precision plugin is called in accelerator setup and should be moved
if "XLA_USE_BF16" in os.environ:
del os.environ["XLA_USE_BF16"]
self._clean_logger(trainer)
return super().start_training(trainer)
def start_evaluating(self, trainer: "pl.Trainer") -> None:
def start_evaluating(self, trainer: "pl.Trainer") -> Any:
self._clean_logger(trainer)
return super().start_evaluating(trainer)
def start_predicting(self, trainer: "pl.Trainer") -> None:
def start_predicting(self, trainer: "pl.Trainer") -> Any:
self._clean_logger(trainer)
return super().start_predicting(trainer)

View File

@@ -30,7 +30,7 @@ from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
from pytorch_lightning.utilities.distributed import ReduceOp
from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT, _PATH, _PREDICT_OUTPUT
from pytorch_lightning.utilities.types import _PATH
TBroadcast = TypeVar("TBroadcast")
@@ -43,7 +43,6 @@ class TrainingTypePlugin(ABC):
self, checkpoint_io: Optional[CheckpointIO] = None, precision_plugin: Optional[PrecisionPlugin] = None
) -> None:
self._model: Optional[Module] = None
self._results: Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]] = None
checkpoint_io = checkpoint_io if checkpoint_io is not None else TorchCheckpointIO()
self._checkpoint_io = checkpoint_io
self._precision_plugin = precision_plugin if precision_plugin is not None else PrecisionPlugin()
@@ -291,18 +290,6 @@ class TrainingTypePlugin(ABC):
"""Returns the pure LightningModule without potential wrappers."""
return unwrap_lightning_module(self._model) if self._model is not None else None
@property
def results(self) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]:
"""Enables plugin-agnostic access to the result returned by the training/evaluation/prediction run.
The result is
cached instead of returned directly, because some plugins require transmitting the results from one
multiprocessing context to another in a separate step. For example, the plugins that use the "spawn"
start-method send the result to the main process through a
`multiprocessing queue (shared memory) <https://pytorch.org/docs/stable/multiprocessing.html>`_.
"""
return self._results
def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:
torch.cuda.empty_cache()
return self.checkpoint_io.load_checkpoint(checkpoint_path)
@@ -315,17 +302,17 @@ class TrainingTypePlugin(ABC):
for optimizer, opt_state in zip(self.optimizers, optimizer_states):
optimizer.load_state_dict(opt_state)
def start_training(self, trainer: "pl.Trainer") -> None:
def start_training(self, trainer: "pl.Trainer") -> Any:
# double dispatch to initiate the training loop
self._results = trainer.run_stage()
return trainer.run_stage()
def start_evaluating(self, trainer: "pl.Trainer") -> None:
def start_evaluating(self, trainer: "pl.Trainer") -> Any:
# double dispatch to initiate the test loop
self._results = trainer.run_stage()
return trainer.run_stage()
def start_predicting(self, trainer: "pl.Trainer") -> None:
def start_predicting(self, trainer: "pl.Trainer") -> Any:
# double dispatch to initiate the predicting loop
self._results = trainer.run_stage()
return trainer.run_stage()
def training_step(self, *args, **kwargs):
return self.model.training_step(*args, **kwargs)
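
For third-party plugins, the contract in this file is now that the `start_*` hooks return whatever `trainer.run_stage()` produces instead of caching it on the removed `_results` attribute. A hedged sketch of a subclass under the new contract (`MyCustomPlugin` is an invented name and the remaining abstract members of `TrainingTypePlugin` are omitted):

from typing import Any

import pytorch_lightning as pl
from pytorch_lightning.plugins import TrainingTypePlugin

class MyCustomPlugin(TrainingTypePlugin):  # illustration only; abstract methods not shown
    def start_training(self, trainer: "pl.Trainer") -> Any:
        # hand the stage output straight back; the Trainer now returns it from `_dispatch`
        return trainer.run_stage()

    def start_evaluating(self, trainer: "pl.Trainer") -> Any:
        return trainer.run_stage()

    def start_predicting(self, trainer: "pl.Trainer") -> Any:
        return trainer.run_stage()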

View File

@@ -1159,9 +1159,8 @@ class Trainer(
self.checkpoint_connector.resume_end()
# dispatch `start_training` or `start_evaluating` or `start_predicting`
self._dispatch()
results = self._dispatch()
# plugin will finalize fitting (e.g. ddp_spawn will load trained model)
self._post_dispatch()
# ----------------------------
@@ -1180,7 +1179,7 @@ class Trainer(
self.state.status = TrainerStatus.FINISHED
self.state.stage = None
return self.training_type_plugin.results
return results
def _pre_dispatch(self):
self.accelerator.pre_dispatch(self)
@@ -1233,13 +1232,13 @@
self.logger_connector.teardown()
self.signal_connector.teardown()
def _dispatch(self):
def _dispatch(self) -> Any:
if self.evaluating:
self.training_type_plugin.start_evaluating(self)
return self.training_type_plugin.start_evaluating(self)
elif self.predicting:
self.training_type_plugin.start_predicting(self)
return self.training_type_plugin.start_predicting(self)
else:
self.training_type_plugin.start_training(self)
return self.training_type_plugin.start_training(self)
def run_stage(self):
self.accelerator.dispatch(self)

View File

@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch
from torch.nn.parallel.distributed import DistributedDataParallel
@@ -38,11 +39,11 @@ class BoringCallbackDDPSpawnModel(BoringModel):
self.log(self.name, self.val)
return super().validation_step(batch, batch_idx)
def add_to_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None:
def add_to_queue(self, queue) -> None:
queue.put("test_val")
return super().add_to_queue(queue)
def get_from_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None:
def get_from_queue(self, queue) -> None:
self.test_val = queue.get()
return super().get_from_queue(queue)
@@ -84,11 +85,11 @@ def test_ddp_spawn_extra_parameters(tmpdir):
class TestDDPSpawnPlugin(DDPSpawnPlugin):
def add_to_queue(self, trainer: Trainer, queue: torch.multiprocessing.SimpleQueue) -> None:
def add_to_queue(self, trainer, queue) -> None:
queue.put("new_test_val")
return super().add_to_queue(trainer, queue)
def get_from_queue(self, trainer: Trainer, queue: torch.multiprocessing.SimpleQueue) -> None:
def get_from_queue(self, trainer: Trainer, queue) -> None:
self.new_test_val = queue.get()
return super().get_from_queue(trainer, queue)