2/n Simplify spawn plugins: Spawn immediately (#10896)

This commit is contained in:
Adrian Wälchli 2021-12-09 19:56:24 +01:00 committed by GitHub
parent 3fcfd0214c
commit a4083df586
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 77 additions and 146 deletions

View File

@ -96,6 +96,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Changed the name of the temporary checkpoint that the `DDPSpawnPlugin` and related plugins save ([#10934](https://github.com/PyTorchLightning/pytorch-lightning/pull/10934))
- Redesigned process creation for spawn-based plugins (`DDPSpawnPlugin`, `TPUSpawnPlugin`, etc.) ([#10896](https://github.com/PyTorchLightning/pytorch-lightning/pull/10896))
* All spawn-based plugins now spawn processes immediately upon calling `Trainer.{fit,validate,test,predict}`
* The hooks/callbacks `prepare_data`, `setup`, `configure_sharded_model` and `teardown` now run under initialized process group for spawn-based plugins just like their non-spawn counterparts
* Some configuration errors that were previously raised as `MisconfigurationException`s will now be raised as `ProcessRaisedException` (torch>=1.8) or as `Exception` (torch<1.8)
### Deprecated
- Deprecated `ClusterEnvironment.master_{address,port}` in favor of `ClusterEnvironment.main_{address,port}` ([#10103](https://github.com/PyTorchLightning/pytorch-lightning/issues/10103))
@ -239,7 +245,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Removed method `training_step`, `test_step`, `validation_step` and `predict_step` from the `Accelerator` ([#10890](https://github.com/PyTorchLightning/pytorch-lightning/pull/10890))
- Removed `HorovodPlugin.start_{training,evaluating,predicting}` hooks ([#10989](https://github.com/PyTorchLightning/pytorch-lightning/pull/10989))
- Removed `TrainingTypePlugin.start_{training,evaluating,predicting}` hooks and the same in all subclasses ([#10989](https://github.com/PyTorchLightning/pytorch-lightning/pull/10989), [#10896](https://github.com/PyTorchLightning/pytorch-lightning/pull/10896))
- Removed `Accelerator.on_train_start` ([#10999](https://github.com/PyTorchLightning/pytorch-lightning/pull/10999))

View File

@ -132,23 +132,6 @@ class DDPSpawnPlugin(ParallelPlugin):
def get_mp_spawn_kwargs(self, trainer: Optional["pl.Trainer"] = None) -> Dict[str, Any]:
return {"nprocs": self.num_processes}
def start_training(self, trainer: "pl.Trainer") -> Any:
spawn_output: _SpawnOutput = self.spawn(self.new_process, trainer)
self._recover_results_in_main_process(spawn_output, trainer)
# reset optimizers, since main process is never used for training and thus does not have a valid optim state
trainer.optimizers = []
return spawn_output.trainer_results
def start_evaluating(self, trainer: "pl.Trainer") -> Any:
spawn_output: _SpawnOutput = self.spawn(self.new_process, trainer)
self._recover_results_in_main_process(spawn_output, trainer)
return spawn_output.trainer_results
def start_predicting(self, trainer: "pl.Trainer") -> Any:
spawn_output: _SpawnOutput = self.spawn(self.new_process, trainer)
self._recover_results_in_main_process(spawn_output, trainer)
return spawn_output.trainer_results
def spawn(self, function: Callable, *args: Any, **kwargs: Any) -> Optional[Union[Any, "_SpawnOutput"]]:
"""Spawn processes that run the given function.
@ -184,7 +167,9 @@ class DDPSpawnPlugin(ParallelPlugin):
self.cluster_environment, self.torch_distributed_backend, self.global_rank, self.world_size
)
def new_process(self, trainer: "pl.Trainer") -> Optional["_SpawnOutput"]:
def pre_dispatch(self, trainer: "pl.Trainer") -> None:
super().pre_dispatch(trainer)
# move the model to the correct device
self.model_to_device()
@ -196,15 +181,6 @@ class DDPSpawnPlugin(ParallelPlugin):
if trainer_fn == TrainerFn.FITTING:
self.configure_ddp()
self.barrier()
results = trainer.run_stage()
outputs = self._collect_rank_zero_results(trainer, results)
# ensure that spawned processes go through teardown before joining
trainer._call_teardown_hook()
return outputs
def pre_configure_ddp(self):
# if unset, default `find_unused_parameters` `True`
# Many models require setting this parameter to True, as there are corner cases
@ -268,7 +244,7 @@ class DDPSpawnPlugin(ParallelPlugin):
return _SpawnOutput(best_model_path, weights_path, trainer.state, results, extra)
def _recover_results_in_main_process(self, spawn_output: "_SpawnOutput", trainer) -> None:
def _recover_results_in_main_process(self, spawn_output: "_SpawnOutput", trainer: "pl.Trainer") -> None:
# transfer back the best path to the trainer
if trainer.checkpoint_callback:
trainer.checkpoint_callback.best_model_path = spawn_output.best_model_path

View File

@ -20,7 +20,7 @@ from torch.optim import Optimizer
import pytorch_lightning as pl
from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin
from pytorch_lightning.plugins.training_type.ddp_spawn import _SpawnOutput, DDPSpawnPlugin
from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, rank_zero_only
from pytorch_lightning.utilities.enums import _StrategyType
@ -114,12 +114,12 @@ class DDPSpawnShardedPlugin(DDPSpawnPlugin):
def post_training_step(self):
pass
def new_process(self, trainer: "pl.Trainer") -> Optional["_SpawnOutput"]:
def pre_dispatch(self, trainer: "pl.Trainer") -> None:
# Ensure that the scaler points to the correct process group
# which is re-initialized in a new process
if isinstance(self.precision_plugin, ShardedNativeMixedPrecisionPlugin):
self._precision_plugin.scaler = ShardedGradScaler()
return super().new_process(trainer)
return super().pre_dispatch(trainer)
@classmethod
def register_plugins(cls, plugin_registry: Dict) -> None:

View File

@ -23,7 +23,6 @@ from torch.nn import Module
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger
from pytorch_lightning.overrides import LightningDistributedModule
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO
@ -118,10 +117,23 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
return super().connect(model)
def pre_dispatch(self, trainer: "pl.Trainer") -> None:
super().pre_dispatch(trainer)
self._move_optimizer_state()
if self.debug:
os.environ["PT_XLA_DEBUG"] = str(1)
if self.tpu_global_core_rank != 0 and trainer.progress_bar_callback is not None:
trainer.progress_bar_callback.disable()
shared_params = find_shared_parameters(self.model)
self.model_to_device()
if is_overridden("on_post_move_to_device", self.lightning_module):
self.model.module.on_post_move_to_device()
else:
set_shared_parameters(self.model.module, shared_params)
self.setup_optimizers(trainer)
self.precision_plugin.connect(self._model, None, None)
def setup(self, trainer: "pl.Trainer") -> None:
self.start_method = "fork"
super().setup(trainer)
@ -154,37 +166,6 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
def set_world_ranks(self, process_idx: int = 0) -> None:
pass
def new_process(self, trainer: "pl.Trainer") -> Optional["_SpawnOutput"]:
if self.tpu_global_core_rank != 0 and trainer.progress_bar_callback is not None:
trainer.progress_bar_callback.disable()
shared_params = find_shared_parameters(self.model)
self.model_to_device()
if is_overridden("on_post_move_to_device", self.lightning_module):
self.model.module.on_post_move_to_device()
else:
set_shared_parameters(self.model.module, shared_params)
trainer.training_type_plugin.setup_optimizers(trainer)
trainer.precision_plugin.connect(self._model, None, None)
self.barrier("pre-run-stage")
results = trainer.run_stage()
outputs = self._collect_rank_zero_results(trainer, results)
# https://github.com/pytorch/xla/issues/1801#issuecomment-602799542
self.barrier("end-process")
# https://github.com/pytorch/xla/issues/2190#issuecomment-641665358
if self.local_rank == 0:
time.sleep(2)
# ensure that spawned processes go through teardown before joining
trainer._call_teardown_hook()
return outputs
def model_to_device(self) -> None:
self.model = self.wrapped_model.to(self.root_device)
@ -215,8 +196,7 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
if is_overridden("add_to_queue", self.lightning_module):
# TODO: Remove the if in v1.7
self.lightning_module.add_to_queue(extra)
else:
self.add_to_queue(trainer, extra)
self.add_to_queue(trainer, extra)
return _SpawnOutput(best_model_path, weights_path, trainer.state, results, extra)
@ -263,6 +243,9 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
}
def spawn(self, function: Callable, *args: Any, **kwargs: Any) -> Optional[Union[Any, "_SpawnOutput"]]:
# todo: precision pluging is call in accelerator setup and should be moved
if "XLA_USE_BF16" in os.environ:
del os.environ["XLA_USE_BF16"]
context = mp.get_context(self.start_method or "fork")
return_queue = context.SimpleQueue()
xmp.spawn(self._wrapped_function, args=(function, args, kwargs, return_queue), **self.get_mp_spawn_kwargs())
@ -276,7 +259,10 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
if self.local_rank == 0:
return_queue.put(move_data_to_device(result, "cpu"))
# https://github.com/pytorch/xla/issues/1801#issuecomment-602799542
self.barrier("end-process")
# Ensure that the rank 0 process is the one exiting last
# https://github.com/pytorch/xla/issues/2190#issuecomment-641665358
if self.local_rank == 0:
time.sleep(2)
@ -287,21 +273,6 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
self.tpu_global_core_rank = xm.get_ordinal()
rank_zero_only.rank = self.global_rank
def start_training(self, trainer: "pl.Trainer") -> Any:
# todo: precision pluging is call in accelerator setup and should be moved
if "XLA_USE_BF16" in os.environ:
del os.environ["XLA_USE_BF16"]
self._clean_logger(trainer)
return super().start_training(trainer)
def start_evaluating(self, trainer: "pl.Trainer") -> Any:
self._clean_logger(trainer)
return super().start_evaluating(trainer)
def start_predicting(self, trainer: "pl.Trainer") -> Any:
self._clean_logger(trainer)
return super().start_predicting(trainer)
def validation_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
with self.precision_plugin.val_step_context():
return self.model(*args, **kwargs)
@ -358,9 +329,7 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
return xm.all_gather(tensor)
def teardown(self) -> None:
# TPU teardown
os.environ.pop("PT_XLA_DEBUG", None)
self.barrier("teardown")
@property
def should_rank_save_checkpoint(self) -> bool:
@ -377,13 +346,3 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
@checkpoint_io.setter
def checkpoint_io(self, plugin: CheckpointIO) -> None:
raise MisconfigurationException("TPU Spawn Plugin currently does not support custom checkpoint plugins.")
@staticmethod
def _clean_logger(trainer: "pl.Trainer") -> None:
loggers = trainer.logger._logger_iterable if isinstance(trainer.logger, LoggerCollection) else [trainer.logger]
for logger in loggers:
if isinstance(logger, TensorBoardLogger) and logger._experiment is not None:
# the experiment class of `TensorBoard` holds a multiprocessing queue which can make ours hang.
# we want to make sure these are closed before we spawn our own threads.
# assuming nothing else references the experiment object, python should instantly `__del__` it.
logger._experiment = None

View File

@ -307,18 +307,6 @@ class TrainingTypePlugin(ABC):
for optimizer, opt_state in zip(self.optimizers, optimizer_states):
optimizer.load_state_dict(opt_state)
def start_training(self, trainer: "pl.Trainer") -> Any:
# double dispatch to initiate the training loop
return trainer.run_stage()
def start_evaluating(self, trainer: "pl.Trainer") -> Any:
# double dispatch to initiate the test loop
return trainer.run_stage()
def start_predicting(self, trainer: "pl.Trainer") -> Any:
# double dispatch to initiate the predicting loop
return trainer.run_stage()
def training_step(self, *args, **kwargs) -> STEP_OUTPUT:
"""The actual training step.

View File

@ -49,6 +49,7 @@ from pytorch_lightning.plugins import (
TrainingTypePlugin,
)
from pytorch_lightning.plugins.environments.slurm_environment import SLURMEnvironment
from pytorch_lightning.plugins.training_type.ddp_spawn import _SpawnOutput
from pytorch_lightning.profiler import (
AdvancedProfiler,
BaseProfiler,
@ -673,7 +674,12 @@ class Trainer(
**kwargs: keyword arguments to be passed to `trainer_fn`
"""
try:
return trainer_fn(*args, **kwargs)
if isinstance(self.training_type_plugin, DDPSpawnPlugin):
spawn_output: _SpawnOutput = self.training_type_plugin.spawn(trainer_fn, *args, **kwargs)
self.training_type_plugin._recover_results_in_main_process(spawn_output, self)
return spawn_output.trainer_results
else:
return trainer_fn(*args, **kwargs)
# TODO: treat KeyboardInterrupt as BaseException (delete the code below) in v1.7
except KeyboardInterrupt as exception:
rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
@ -721,6 +727,7 @@ class Trainer(
datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`.
"""
self.training_type_plugin.model = model
self._call_and_handle_interrupt(
self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
)
@ -756,10 +763,11 @@ class Trainer(
# TODO: ckpt_path only in v1.7
ckpt_path = ckpt_path or self.resume_from_checkpoint
self._run(model, ckpt_path=ckpt_path)
results = self._run(model, ckpt_path=ckpt_path)
assert self.state.stopped
self.training = False
return results
def validate(
self,
@ -793,6 +801,7 @@ class Trainer(
:meth:`~pytorch_lightning.core.lightning.LightningModule.validation_epoch_end`, etc.
The length of the list corresponds to the number of validation dataloaders used.
"""
self.training_type_plugin.model = model or self.lightning_module
return self._call_and_handle_interrupt(self._validate_impl, model, dataloaders, ckpt_path, verbose, datamodule)
def _validate_impl(
@ -876,6 +885,7 @@ class Trainer(
:meth:`~pytorch_lightning.core.lightning.LightningModule.test_epoch_end`, etc.
The length of the list corresponds to the number of test dataloaders used.
"""
self.training_type_plugin.model = model or self.lightning_module
return self._call_and_handle_interrupt(self._test_impl, model, dataloaders, ckpt_path, verbose, datamodule)
def _test_impl(
@ -958,6 +968,7 @@ class Trainer(
Returns:
Returns a list of dictionaries, one for each provided dataloader containing their respective predictions.
"""
self.training_type_plugin.model = model or self.lightning_module
return self._call_and_handle_interrupt(
self._predict_impl, model, dataloaders, datamodule, return_predictions, ckpt_path
)
@ -1113,25 +1124,21 @@ class Trainer(
Lightning internal flow looks like this:
{Trainer.fit} or {Trainer.test} or {Trainer.predict} ||
| ||
create accelerator ||
spawn processes ||
{self.accelerator.setup_environment} ||
| ||
{self._dispatch} ||
| || LIGHTNING
{self.training_type_plugin.start_training} ||
or {self.training_type_plugin.start_evaluating} ||
or {self.training_type_plugin.start_predicting} || FLOW
setup accelerator ||
and strategy || LIGHTNING
| ||
{self.run_stage} ||
| || DIRECTION
{self._run_train} ||
{self.run_stage} || FLOW
| ||
{self._run_train} || DIRECTION
or {self._run_evaluate} ||
or {self._run_predict} ||
| ||
results \/
This is used to guide readers to the core loops: train, test, predict.
{self._run_predict} is the simplest to understand, use `Go to Definition` to read it :)
Search for `start_training` or `start_evaluating` or `start_predicting` in
`pytorch_lightning/plugins/training_type_plugin` to find accelerator dispatch functions.
"""
# ----------------------------
@ -1147,7 +1154,7 @@ class Trainer(
self._call_callback_hooks("on_fit_start")
self._call_lightning_module_hook("on_fit_start")
# plugin will setup fitting (e.g. ddp will launch child processes)
# plugin will move model to device
self._pre_dispatch()
if self.training_type_plugin.restore_checkpoint_after_pre_dispatch:
@ -1158,10 +1165,8 @@ class Trainer(
self.checkpoint_connector.resume_end()
# dispatch `start_training` or `start_evaluating` or `start_predicting`
results = self._dispatch()
self._post_dispatch()
results = self.run_stage()
self._teardown()
# ----------------------------
# POST-Training CLEAN UP
@ -1171,15 +1176,15 @@ class Trainer(
self._call_callback_hooks("on_fit_end")
self._call_lightning_module_hook("on_fit_end")
# teardown if necessary (similar calls for spawn plugins are excluded as they have
# been included at the end of `new_process` functions)
if not isinstance(self.training_type_plugin, DDPSpawnPlugin):
self._call_teardown_hook()
self._call_teardown_hook()
if self.state.status != TrainerStatus.INTERRUPTED:
self.state.status = TrainerStatus.FINISHED
self.state.stage = None
if isinstance(self.training_type_plugin, DDPSpawnPlugin):
results = self.training_type_plugin._collect_rank_zero_results(self, results)
return results
def _pre_dispatch(self):
@ -1223,9 +1228,9 @@ class Trainer(
self.logger.log_graph(self.lightning_module)
self.logger.save()
def _post_dispatch(self):
# these `teardown` calls are here instead of in `_call_teardown_hook` since they are internal teardowns
# which need to happen before.
def _teardown(self):
"""This is the Trainer's internal teardown, unrelated to the `teardown` hooks in LightningModule and
Callback; those are handled by :meth:`_call_teardown_hook`."""
self.training_type_plugin.post_dispatch(self)
self.accelerator.teardown()
self._data_connector.teardown()
@ -1233,15 +1238,8 @@ class Trainer(
self.logger_connector.teardown()
self.signal_connector.teardown()
def _dispatch(self) -> Any:
if self.evaluating:
return self.training_type_plugin.start_evaluating(self)
elif self.predicting:
return self.training_type_plugin.start_predicting(self)
else:
return self.training_type_plugin.start_training(self)
def run_stage(self):
self.training_type_plugin.barrier("run-stage")
self.training_type_plugin.dispatch(self)
self.__setup_profiler()

View File

@ -80,8 +80,7 @@ def test_ddp_spawn_extra_parameters(tmpdir):
val_name: str = "val_acc"
model = BoringCallbackDDPSpawnModel(val_name, val)
dm = BoringDataModule()
with pytest.deprecated_call(match="add_to_queue` method was deprecated in v1.5"):
trainer.fit(model, datamodule=dm)
trainer.fit(model, datamodule=dm)
assert trainer.callback_metrics[val_name] == torch.tensor(val)
assert model.test_val == "test_val"
@ -107,8 +106,7 @@ def test_ddp_spawn_add_get_queue(tmpdir):
val_name: str = "val_acc"
model = BoringCallbackDDPSpawnModel(val_name, val)
dm = BoringDataModule()
with pytest.deprecated_call(match="add_to_queue` method was deprecated in v1.5"):
trainer.fit(model, datamodule=dm)
trainer.fit(model, datamodule=dm)
assert trainer.callback_metrics[val_name] == torch.tensor(val)
assert ddp_spawn_plugin.new_test_val == "new_test_val"

View File

@ -50,7 +50,7 @@ from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import _AcceleratorType, _StrategyType
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.exceptions import DeadlockDetectedException, MisconfigurationException
from pytorch_lightning.utilities.imports import _OMEGACONF_AVAILABLE
from pytorch_lightning.utilities.imports import _IS_WINDOWS, _OMEGACONF_AVAILABLE, _TORCH_GREATER_EQUAL_1_8
from pytorch_lightning.utilities.seed import seed_everything
from tests.helpers import BoringModel, RandomDataset
from tests.helpers.boring_model import RandomIterableDataset, RandomIterableDatasetWithLen
@ -61,6 +61,11 @@ from tests.helpers.simple_models import ClassificationModel
if _OMEGACONF_AVAILABLE:
from omegaconf import OmegaConf
if _TORCH_GREATER_EQUAL_1_8:
from torch.multiprocessing import ProcessRaisedException
else:
ProcessRaisedException = Exception
@pytest.mark.parametrize("url_ckpt", [True, False])
def test_no_val_module(monkeypatch, tmpdir, tmpdir_server, url_ckpt):
@ -1419,7 +1424,7 @@ def predict(
callbacks=[cb, cb_1] if use_callbacks else [],
)
if strategy == "ddp_spawn":
with pytest.raises(MisconfigurationException):
with pytest.raises(ProcessRaisedException, match="`return_predictions` should be set to `False`"):
trainer.predict(model, datamodule=dm, return_predictions=True)
if datamodule:
@ -1517,6 +1522,7 @@ def test_index_batch_sampler_wrapper_with_iterable_dataset(dataset_cls, tmpdir):
assert len(predictions) == 8
@pytest.mark.skipif(_IS_WINDOWS and not _TORCH_GREATER_EQUAL_1_8, reason="torch.distributed support required")
@patch("torch.cuda.device_count", return_value=2)
@patch("torch.cuda.is_available", return_value=True)
@pytest.mark.parametrize("accelerator", ("cpu", "gpu"))
@ -1525,7 +1531,7 @@ def test_spawn_predict_return_predictions(_, __, accelerator):
model = BoringModel()
trainer = Trainer(accelerator=accelerator, strategy="ddp_spawn", devices=2, fast_dev_run=True)
assert isinstance(trainer.training_type_plugin, DDPSpawnPlugin)
with pytest.raises(MisconfigurationException, match="`return_predictions` should be set to `False`"):
with pytest.raises(ProcessRaisedException, match="`return_predictions` should be set to `False`"):
trainer.predict(model, dataloaders=model.train_dataloader(), return_predictions=True)