2/n Simplify spawn plugins: Spawn immediately (#10896)
This commit is contained in:
parent
3fcfd0214c
commit
a4083df586
|
@ -96,6 +96,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Changed the name of the temporary checkpoint that the `DDPSpawnPlugin` and related plugins save ([#10934](https://github.com/PyTorchLightning/pytorch-lightning/pull/10934))
|
||||
|
||||
|
||||
- Redesigned process creation for spawn-based plugins (`DDPSpawnPlugin`, `TPUSpawnPlugin`, etc.) ([#10896](https://github.com/PyTorchLightning/pytorch-lightning/pull/10896))
|
||||
* All spawn-based plugins now spawn processes immediately upon calling `Trainer.{fit,validate,test,predict}`
|
||||
* The hooks/callbacks `prepare_data`, `setup`, `configure_sharded_model` and `teardown` now run under initialized process group for spawn-based plugins just like their non-spawn counterparts
|
||||
* Some configuration errors that were previously raised as `MisconfigurationException`s will now be raised as `ProcessRaisedException` (torch>=1.8) or as `Exception` (torch<1.8)
|
||||
|
||||
|
||||
### Deprecated
|
||||
|
||||
- Deprecated `ClusterEnvironment.master_{address,port}` in favor of `ClusterEnvironment.main_{address,port}` ([#10103](https://github.com/PyTorchLightning/pytorch-lightning/issues/10103))
|
||||
|
@ -239,7 +245,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Removed method `training_step`, `test_step`, `validation_step` and `predict_step` from the `Accelerator` ([#10890](https://github.com/PyTorchLightning/pytorch-lightning/pull/10890))
|
||||
|
||||
|
||||
- Removed `HorovodPlugin.start_{training,evaluating,predicting}` hooks ([#10989](https://github.com/PyTorchLightning/pytorch-lightning/pull/10989))
|
||||
- Removed `TrainingTypePlugin.start_{training,evaluating,predicting}` hooks and the same in all subclasses ([#10989](https://github.com/PyTorchLightning/pytorch-lightning/pull/10989), [#10896](https://github.com/PyTorchLightning/pytorch-lightning/pull/10896))
|
||||
|
||||
|
||||
- Removed `Accelerator.on_train_start` ([#10999](https://github.com/PyTorchLightning/pytorch-lightning/pull/10999))
|
||||
|
|
|
@ -132,23 +132,6 @@ class DDPSpawnPlugin(ParallelPlugin):
|
|||
def get_mp_spawn_kwargs(self, trainer: Optional["pl.Trainer"] = None) -> Dict[str, Any]:
|
||||
return {"nprocs": self.num_processes}
|
||||
|
||||
def start_training(self, trainer: "pl.Trainer") -> Any:
|
||||
spawn_output: _SpawnOutput = self.spawn(self.new_process, trainer)
|
||||
self._recover_results_in_main_process(spawn_output, trainer)
|
||||
# reset optimizers, since main process is never used for training and thus does not have a valid optim state
|
||||
trainer.optimizers = []
|
||||
return spawn_output.trainer_results
|
||||
|
||||
def start_evaluating(self, trainer: "pl.Trainer") -> Any:
|
||||
spawn_output: _SpawnOutput = self.spawn(self.new_process, trainer)
|
||||
self._recover_results_in_main_process(spawn_output, trainer)
|
||||
return spawn_output.trainer_results
|
||||
|
||||
def start_predicting(self, trainer: "pl.Trainer") -> Any:
|
||||
spawn_output: _SpawnOutput = self.spawn(self.new_process, trainer)
|
||||
self._recover_results_in_main_process(spawn_output, trainer)
|
||||
return spawn_output.trainer_results
|
||||
|
||||
def spawn(self, function: Callable, *args: Any, **kwargs: Any) -> Optional[Union[Any, "_SpawnOutput"]]:
|
||||
"""Spawn processes that run the given function.
|
||||
|
||||
|
@ -184,7 +167,9 @@ class DDPSpawnPlugin(ParallelPlugin):
|
|||
self.cluster_environment, self.torch_distributed_backend, self.global_rank, self.world_size
|
||||
)
|
||||
|
||||
def new_process(self, trainer: "pl.Trainer") -> Optional["_SpawnOutput"]:
|
||||
def pre_dispatch(self, trainer: "pl.Trainer") -> None:
|
||||
super().pre_dispatch(trainer)
|
||||
|
||||
# move the model to the correct device
|
||||
self.model_to_device()
|
||||
|
||||
|
@ -196,15 +181,6 @@ class DDPSpawnPlugin(ParallelPlugin):
|
|||
if trainer_fn == TrainerFn.FITTING:
|
||||
self.configure_ddp()
|
||||
|
||||
self.barrier()
|
||||
|
||||
results = trainer.run_stage()
|
||||
outputs = self._collect_rank_zero_results(trainer, results)
|
||||
|
||||
# ensure that spawned processes go through teardown before joining
|
||||
trainer._call_teardown_hook()
|
||||
return outputs
|
||||
|
||||
def pre_configure_ddp(self):
|
||||
# if unset, default `find_unused_parameters` `True`
|
||||
# Many models require setting this parameter to True, as there are corner cases
|
||||
|
@ -268,7 +244,7 @@ class DDPSpawnPlugin(ParallelPlugin):
|
|||
|
||||
return _SpawnOutput(best_model_path, weights_path, trainer.state, results, extra)
|
||||
|
||||
def _recover_results_in_main_process(self, spawn_output: "_SpawnOutput", trainer) -> None:
|
||||
def _recover_results_in_main_process(self, spawn_output: "_SpawnOutput", trainer: "pl.Trainer") -> None:
|
||||
# transfer back the best path to the trainer
|
||||
if trainer.checkpoint_callback:
|
||||
trainer.checkpoint_callback.best_model_path = spawn_output.best_model_path
|
||||
|
|
|
@ -20,7 +20,7 @@ from torch.optim import Optimizer
|
|||
|
||||
import pytorch_lightning as pl
|
||||
from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin
|
||||
from pytorch_lightning.plugins.training_type.ddp_spawn import _SpawnOutput, DDPSpawnPlugin
|
||||
from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
|
||||
from pytorch_lightning.trainer.states import TrainerFn
|
||||
from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, rank_zero_only
|
||||
from pytorch_lightning.utilities.enums import _StrategyType
|
||||
|
@ -114,12 +114,12 @@ class DDPSpawnShardedPlugin(DDPSpawnPlugin):
|
|||
def post_training_step(self):
|
||||
pass
|
||||
|
||||
def new_process(self, trainer: "pl.Trainer") -> Optional["_SpawnOutput"]:
|
||||
def pre_dispatch(self, trainer: "pl.Trainer") -> None:
|
||||
# Ensure that the scaler points to the correct process group
|
||||
# which is re-initialized in a new process
|
||||
if isinstance(self.precision_plugin, ShardedNativeMixedPrecisionPlugin):
|
||||
self._precision_plugin.scaler = ShardedGradScaler()
|
||||
return super().new_process(trainer)
|
||||
return super().pre_dispatch(trainer)
|
||||
|
||||
@classmethod
|
||||
def register_plugins(cls, plugin_registry: Dict) -> None:
|
||||
|
|
|
@ -23,7 +23,6 @@ from torch.nn import Module
|
|||
from torch.utils.data import DataLoader
|
||||
|
||||
import pytorch_lightning as pl
|
||||
from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger
|
||||
from pytorch_lightning.overrides import LightningDistributedModule
|
||||
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
|
||||
from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO
|
||||
|
@ -118,10 +117,23 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
|
|||
return super().connect(model)
|
||||
|
||||
def pre_dispatch(self, trainer: "pl.Trainer") -> None:
|
||||
super().pre_dispatch(trainer)
|
||||
self._move_optimizer_state()
|
||||
if self.debug:
|
||||
os.environ["PT_XLA_DEBUG"] = str(1)
|
||||
|
||||
if self.tpu_global_core_rank != 0 and trainer.progress_bar_callback is not None:
|
||||
trainer.progress_bar_callback.disable()
|
||||
|
||||
shared_params = find_shared_parameters(self.model)
|
||||
self.model_to_device()
|
||||
if is_overridden("on_post_move_to_device", self.lightning_module):
|
||||
self.model.module.on_post_move_to_device()
|
||||
else:
|
||||
set_shared_parameters(self.model.module, shared_params)
|
||||
|
||||
self.setup_optimizers(trainer)
|
||||
self.precision_plugin.connect(self._model, None, None)
|
||||
|
||||
def setup(self, trainer: "pl.Trainer") -> None:
|
||||
self.start_method = "fork"
|
||||
super().setup(trainer)
|
||||
|
@ -154,37 +166,6 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
|
|||
def set_world_ranks(self, process_idx: int = 0) -> None:
|
||||
pass
|
||||
|
||||
def new_process(self, trainer: "pl.Trainer") -> Optional["_SpawnOutput"]:
|
||||
if self.tpu_global_core_rank != 0 and trainer.progress_bar_callback is not None:
|
||||
trainer.progress_bar_callback.disable()
|
||||
|
||||
shared_params = find_shared_parameters(self.model)
|
||||
self.model_to_device()
|
||||
if is_overridden("on_post_move_to_device", self.lightning_module):
|
||||
self.model.module.on_post_move_to_device()
|
||||
else:
|
||||
set_shared_parameters(self.model.module, shared_params)
|
||||
|
||||
trainer.training_type_plugin.setup_optimizers(trainer)
|
||||
trainer.precision_plugin.connect(self._model, None, None)
|
||||
|
||||
self.barrier("pre-run-stage")
|
||||
|
||||
results = trainer.run_stage()
|
||||
|
||||
outputs = self._collect_rank_zero_results(trainer, results)
|
||||
|
||||
# https://github.com/pytorch/xla/issues/1801#issuecomment-602799542
|
||||
self.barrier("end-process")
|
||||
|
||||
# https://github.com/pytorch/xla/issues/2190#issuecomment-641665358
|
||||
if self.local_rank == 0:
|
||||
time.sleep(2)
|
||||
|
||||
# ensure that spawned processes go through teardown before joining
|
||||
trainer._call_teardown_hook()
|
||||
return outputs
|
||||
|
||||
def model_to_device(self) -> None:
|
||||
self.model = self.wrapped_model.to(self.root_device)
|
||||
|
||||
|
@ -215,7 +196,6 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
|
|||
if is_overridden("add_to_queue", self.lightning_module):
|
||||
# TODO: Remove the if in v1.7
|
||||
self.lightning_module.add_to_queue(extra)
|
||||
else:
|
||||
self.add_to_queue(trainer, extra)
|
||||
|
||||
return _SpawnOutput(best_model_path, weights_path, trainer.state, results, extra)
|
||||
|
@ -263,6 +243,9 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
|
|||
}
|
||||
|
||||
def spawn(self, function: Callable, *args: Any, **kwargs: Any) -> Optional[Union[Any, "_SpawnOutput"]]:
|
||||
# todo: precision pluging is call in accelerator setup and should be moved
|
||||
if "XLA_USE_BF16" in os.environ:
|
||||
del os.environ["XLA_USE_BF16"]
|
||||
context = mp.get_context(self.start_method or "fork")
|
||||
return_queue = context.SimpleQueue()
|
||||
xmp.spawn(self._wrapped_function, args=(function, args, kwargs, return_queue), **self.get_mp_spawn_kwargs())
|
||||
|
@ -276,7 +259,10 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
|
|||
if self.local_rank == 0:
|
||||
return_queue.put(move_data_to_device(result, "cpu"))
|
||||
|
||||
# https://github.com/pytorch/xla/issues/1801#issuecomment-602799542
|
||||
self.barrier("end-process")
|
||||
|
||||
# Ensure that the rank 0 process is the one exiting last
|
||||
# https://github.com/pytorch/xla/issues/2190#issuecomment-641665358
|
||||
if self.local_rank == 0:
|
||||
time.sleep(2)
|
||||
|
@ -287,21 +273,6 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
|
|||
self.tpu_global_core_rank = xm.get_ordinal()
|
||||
rank_zero_only.rank = self.global_rank
|
||||
|
||||
def start_training(self, trainer: "pl.Trainer") -> Any:
|
||||
# todo: precision pluging is call in accelerator setup and should be moved
|
||||
if "XLA_USE_BF16" in os.environ:
|
||||
del os.environ["XLA_USE_BF16"]
|
||||
self._clean_logger(trainer)
|
||||
return super().start_training(trainer)
|
||||
|
||||
def start_evaluating(self, trainer: "pl.Trainer") -> Any:
|
||||
self._clean_logger(trainer)
|
||||
return super().start_evaluating(trainer)
|
||||
|
||||
def start_predicting(self, trainer: "pl.Trainer") -> Any:
|
||||
self._clean_logger(trainer)
|
||||
return super().start_predicting(trainer)
|
||||
|
||||
def validation_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
|
||||
with self.precision_plugin.val_step_context():
|
||||
return self.model(*args, **kwargs)
|
||||
|
@ -358,9 +329,7 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
|
|||
return xm.all_gather(tensor)
|
||||
|
||||
def teardown(self) -> None:
|
||||
# TPU teardown
|
||||
os.environ.pop("PT_XLA_DEBUG", None)
|
||||
self.barrier("teardown")
|
||||
|
||||
@property
|
||||
def should_rank_save_checkpoint(self) -> bool:
|
||||
|
@ -377,13 +346,3 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
|
|||
@checkpoint_io.setter
|
||||
def checkpoint_io(self, plugin: CheckpointIO) -> None:
|
||||
raise MisconfigurationException("TPU Spawn Plugin currently does not support custom checkpoint plugins.")
|
||||
|
||||
@staticmethod
|
||||
def _clean_logger(trainer: "pl.Trainer") -> None:
|
||||
loggers = trainer.logger._logger_iterable if isinstance(trainer.logger, LoggerCollection) else [trainer.logger]
|
||||
for logger in loggers:
|
||||
if isinstance(logger, TensorBoardLogger) and logger._experiment is not None:
|
||||
# the experiment class of `TensorBoard` holds a multiprocessing queue which can make ours hang.
|
||||
# we want to make sure these are closed before we spawn our own threads.
|
||||
# assuming nothing else references the experiment object, python should instantly `__del__` it.
|
||||
logger._experiment = None
|
||||
|
|
|
@ -307,18 +307,6 @@ class TrainingTypePlugin(ABC):
|
|||
for optimizer, opt_state in zip(self.optimizers, optimizer_states):
|
||||
optimizer.load_state_dict(opt_state)
|
||||
|
||||
def start_training(self, trainer: "pl.Trainer") -> Any:
|
||||
# double dispatch to initiate the training loop
|
||||
return trainer.run_stage()
|
||||
|
||||
def start_evaluating(self, trainer: "pl.Trainer") -> Any:
|
||||
# double dispatch to initiate the test loop
|
||||
return trainer.run_stage()
|
||||
|
||||
def start_predicting(self, trainer: "pl.Trainer") -> Any:
|
||||
# double dispatch to initiate the predicting loop
|
||||
return trainer.run_stage()
|
||||
|
||||
def training_step(self, *args, **kwargs) -> STEP_OUTPUT:
|
||||
"""The actual training step.
|
||||
|
||||
|
|
|
@ -49,6 +49,7 @@ from pytorch_lightning.plugins import (
|
|||
TrainingTypePlugin,
|
||||
)
|
||||
from pytorch_lightning.plugins.environments.slurm_environment import SLURMEnvironment
|
||||
from pytorch_lightning.plugins.training_type.ddp_spawn import _SpawnOutput
|
||||
from pytorch_lightning.profiler import (
|
||||
AdvancedProfiler,
|
||||
BaseProfiler,
|
||||
|
@ -673,6 +674,11 @@ class Trainer(
|
|||
**kwargs: keyword arguments to be passed to `trainer_fn`
|
||||
"""
|
||||
try:
|
||||
if isinstance(self.training_type_plugin, DDPSpawnPlugin):
|
||||
spawn_output: _SpawnOutput = self.training_type_plugin.spawn(trainer_fn, *args, **kwargs)
|
||||
self.training_type_plugin._recover_results_in_main_process(spawn_output, self)
|
||||
return spawn_output.trainer_results
|
||||
else:
|
||||
return trainer_fn(*args, **kwargs)
|
||||
# TODO: treat KeyboardInterrupt as BaseException (delete the code below) in v1.7
|
||||
except KeyboardInterrupt as exception:
|
||||
|
@ -721,6 +727,7 @@ class Trainer(
|
|||
|
||||
datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`.
|
||||
"""
|
||||
self.training_type_plugin.model = model
|
||||
self._call_and_handle_interrupt(
|
||||
self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
|
||||
)
|
||||
|
@ -756,10 +763,11 @@ class Trainer(
|
|||
|
||||
# TODO: ckpt_path only in v1.7
|
||||
ckpt_path = ckpt_path or self.resume_from_checkpoint
|
||||
self._run(model, ckpt_path=ckpt_path)
|
||||
results = self._run(model, ckpt_path=ckpt_path)
|
||||
|
||||
assert self.state.stopped
|
||||
self.training = False
|
||||
return results
|
||||
|
||||
def validate(
|
||||
self,
|
||||
|
@ -793,6 +801,7 @@ class Trainer(
|
|||
:meth:`~pytorch_lightning.core.lightning.LightningModule.validation_epoch_end`, etc.
|
||||
The length of the list corresponds to the number of validation dataloaders used.
|
||||
"""
|
||||
self.training_type_plugin.model = model or self.lightning_module
|
||||
return self._call_and_handle_interrupt(self._validate_impl, model, dataloaders, ckpt_path, verbose, datamodule)
|
||||
|
||||
def _validate_impl(
|
||||
|
@ -876,6 +885,7 @@ class Trainer(
|
|||
:meth:`~pytorch_lightning.core.lightning.LightningModule.test_epoch_end`, etc.
|
||||
The length of the list corresponds to the number of test dataloaders used.
|
||||
"""
|
||||
self.training_type_plugin.model = model or self.lightning_module
|
||||
return self._call_and_handle_interrupt(self._test_impl, model, dataloaders, ckpt_path, verbose, datamodule)
|
||||
|
||||
def _test_impl(
|
||||
|
@ -958,6 +968,7 @@ class Trainer(
|
|||
Returns:
|
||||
Returns a list of dictionaries, one for each provided dataloader containing their respective predictions.
|
||||
"""
|
||||
self.training_type_plugin.model = model or self.lightning_module
|
||||
return self._call_and_handle_interrupt(
|
||||
self._predict_impl, model, dataloaders, datamodule, return_predictions, ckpt_path
|
||||
)
|
||||
|
@ -1113,25 +1124,21 @@ class Trainer(
|
|||
Lightning internal flow looks like this:
|
||||
{Trainer.fit} or {Trainer.test} or {Trainer.predict} ||
|
||||
| ||
|
||||
create accelerator ||
|
||||
spawn processes ||
|
||||
{self.accelerator.setup_environment} ||
|
||||
| ||
|
||||
{self._dispatch} ||
|
||||
| || LIGHTNING
|
||||
{self.training_type_plugin.start_training} ||
|
||||
or {self.training_type_plugin.start_evaluating} ||
|
||||
or {self.training_type_plugin.start_predicting} || FLOW
|
||||
setup accelerator ||
|
||||
and strategy || LIGHTNING
|
||||
| ||
|
||||
{self.run_stage} ||
|
||||
| || DIRECTION
|
||||
{self._run_train} ||
|
||||
{self.run_stage} || FLOW
|
||||
| ||
|
||||
{self._run_train} || DIRECTION
|
||||
or {self._run_evaluate} ||
|
||||
or {self._run_predict} ||
|
||||
| ||
|
||||
results \/
|
||||
This is used to guide readers to the core loops: train, test, predict.
|
||||
{self._run_predict} is the simplest to understand, use `Go to Definition` to read it :)
|
||||
Search for `start_training` or `start_evaluating` or `start_predicting` in
|
||||
`pytorch_lightning/plugins/training_type_plugin` to find accelerator dispatch functions.
|
||||
"""
|
||||
|
||||
# ----------------------------
|
||||
|
@ -1147,7 +1154,7 @@ class Trainer(
|
|||
self._call_callback_hooks("on_fit_start")
|
||||
self._call_lightning_module_hook("on_fit_start")
|
||||
|
||||
# plugin will setup fitting (e.g. ddp will launch child processes)
|
||||
# plugin will move model to device
|
||||
self._pre_dispatch()
|
||||
|
||||
if self.training_type_plugin.restore_checkpoint_after_pre_dispatch:
|
||||
|
@ -1158,10 +1165,8 @@ class Trainer(
|
|||
|
||||
self.checkpoint_connector.resume_end()
|
||||
|
||||
# dispatch `start_training` or `start_evaluating` or `start_predicting`
|
||||
results = self._dispatch()
|
||||
|
||||
self._post_dispatch()
|
||||
results = self.run_stage()
|
||||
self._teardown()
|
||||
|
||||
# ----------------------------
|
||||
# POST-Training CLEAN UP
|
||||
|
@ -1171,15 +1176,15 @@ class Trainer(
|
|||
self._call_callback_hooks("on_fit_end")
|
||||
self._call_lightning_module_hook("on_fit_end")
|
||||
|
||||
# teardown if necessary (similar calls for spawn plugins are excluded as they have
|
||||
# been included at the end of `new_process` functions)
|
||||
if not isinstance(self.training_type_plugin, DDPSpawnPlugin):
|
||||
self._call_teardown_hook()
|
||||
|
||||
if self.state.status != TrainerStatus.INTERRUPTED:
|
||||
self.state.status = TrainerStatus.FINISHED
|
||||
self.state.stage = None
|
||||
|
||||
if isinstance(self.training_type_plugin, DDPSpawnPlugin):
|
||||
results = self.training_type_plugin._collect_rank_zero_results(self, results)
|
||||
|
||||
return results
|
||||
|
||||
def _pre_dispatch(self):
|
||||
|
@ -1223,9 +1228,9 @@ class Trainer(
|
|||
self.logger.log_graph(self.lightning_module)
|
||||
self.logger.save()
|
||||
|
||||
def _post_dispatch(self):
|
||||
# these `teardown` calls are here instead of in `_call_teardown_hook` since they are internal teardowns
|
||||
# which need to happen before.
|
||||
def _teardown(self):
|
||||
"""This is the Trainer's internal teardown, unrelated to the `teardown` hooks in LightningModule and
|
||||
Callback; those are handled by :meth:`_call_teardown_hook`."""
|
||||
self.training_type_plugin.post_dispatch(self)
|
||||
self.accelerator.teardown()
|
||||
self._data_connector.teardown()
|
||||
|
@ -1233,15 +1238,8 @@ class Trainer(
|
|||
self.logger_connector.teardown()
|
||||
self.signal_connector.teardown()
|
||||
|
||||
def _dispatch(self) -> Any:
|
||||
if self.evaluating:
|
||||
return self.training_type_plugin.start_evaluating(self)
|
||||
elif self.predicting:
|
||||
return self.training_type_plugin.start_predicting(self)
|
||||
else:
|
||||
return self.training_type_plugin.start_training(self)
|
||||
|
||||
def run_stage(self):
|
||||
self.training_type_plugin.barrier("run-stage")
|
||||
self.training_type_plugin.dispatch(self)
|
||||
self.__setup_profiler()
|
||||
|
||||
|
|
|
@ -80,7 +80,6 @@ def test_ddp_spawn_extra_parameters(tmpdir):
|
|||
val_name: str = "val_acc"
|
||||
model = BoringCallbackDDPSpawnModel(val_name, val)
|
||||
dm = BoringDataModule()
|
||||
with pytest.deprecated_call(match="add_to_queue` method was deprecated in v1.5"):
|
||||
trainer.fit(model, datamodule=dm)
|
||||
assert trainer.callback_metrics[val_name] == torch.tensor(val)
|
||||
assert model.test_val == "test_val"
|
||||
|
@ -107,7 +106,6 @@ def test_ddp_spawn_add_get_queue(tmpdir):
|
|||
val_name: str = "val_acc"
|
||||
model = BoringCallbackDDPSpawnModel(val_name, val)
|
||||
dm = BoringDataModule()
|
||||
with pytest.deprecated_call(match="add_to_queue` method was deprecated in v1.5"):
|
||||
trainer.fit(model, datamodule=dm)
|
||||
assert trainer.callback_metrics[val_name] == torch.tensor(val)
|
||||
assert ddp_spawn_plugin.new_test_val == "new_test_val"
|
||||
|
|
|
@ -50,7 +50,7 @@ from pytorch_lightning.trainer.states import TrainerFn
|
|||
from pytorch_lightning.utilities import _AcceleratorType, _StrategyType
|
||||
from pytorch_lightning.utilities.cloud_io import load as pl_load
|
||||
from pytorch_lightning.utilities.exceptions import DeadlockDetectedException, MisconfigurationException
|
||||
from pytorch_lightning.utilities.imports import _OMEGACONF_AVAILABLE
|
||||
from pytorch_lightning.utilities.imports import _IS_WINDOWS, _OMEGACONF_AVAILABLE, _TORCH_GREATER_EQUAL_1_8
|
||||
from pytorch_lightning.utilities.seed import seed_everything
|
||||
from tests.helpers import BoringModel, RandomDataset
|
||||
from tests.helpers.boring_model import RandomIterableDataset, RandomIterableDatasetWithLen
|
||||
|
@ -61,6 +61,11 @@ from tests.helpers.simple_models import ClassificationModel
|
|||
if _OMEGACONF_AVAILABLE:
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
if _TORCH_GREATER_EQUAL_1_8:
|
||||
from torch.multiprocessing import ProcessRaisedException
|
||||
else:
|
||||
ProcessRaisedException = Exception
|
||||
|
||||
|
||||
@pytest.mark.parametrize("url_ckpt", [True, False])
|
||||
def test_no_val_module(monkeypatch, tmpdir, tmpdir_server, url_ckpt):
|
||||
|
@ -1419,7 +1424,7 @@ def predict(
|
|||
callbacks=[cb, cb_1] if use_callbacks else [],
|
||||
)
|
||||
if strategy == "ddp_spawn":
|
||||
with pytest.raises(MisconfigurationException):
|
||||
with pytest.raises(ProcessRaisedException, match="`return_predictions` should be set to `False`"):
|
||||
trainer.predict(model, datamodule=dm, return_predictions=True)
|
||||
|
||||
if datamodule:
|
||||
|
@ -1517,6 +1522,7 @@ def test_index_batch_sampler_wrapper_with_iterable_dataset(dataset_cls, tmpdir):
|
|||
assert len(predictions) == 8
|
||||
|
||||
|
||||
@pytest.mark.skipif(_IS_WINDOWS and not _TORCH_GREATER_EQUAL_1_8, reason="torch.distributed support required")
|
||||
@patch("torch.cuda.device_count", return_value=2)
|
||||
@patch("torch.cuda.is_available", return_value=True)
|
||||
@pytest.mark.parametrize("accelerator", ("cpu", "gpu"))
|
||||
|
@ -1525,7 +1531,7 @@ def test_spawn_predict_return_predictions(_, __, accelerator):
|
|||
model = BoringModel()
|
||||
trainer = Trainer(accelerator=accelerator, strategy="ddp_spawn", devices=2, fast_dev_run=True)
|
||||
assert isinstance(trainer.training_type_plugin, DDPSpawnPlugin)
|
||||
with pytest.raises(MisconfigurationException, match="`return_predictions` should be set to `False`"):
|
||||
with pytest.raises(ProcessRaisedException, match="`return_predictions` should be set to `False`"):
|
||||
trainer.predict(model, dataloaders=model.train_dataloader(), return_predictions=True)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue