diff --git a/CHANGELOG.md b/CHANGELOG.md
index 8743068a6f..b5636dab02 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -112,6 +112,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* All spawn-based plugins now spawn processes immediately upon calling `Trainer.{fit,validate,test,predict}`
* The hooks/callbacks `prepare_data`, `setup`, `configure_sharded_model` and `teardown` now run under initialized process group for spawn-based plugins just like their non-spawn counterparts
* Some configuration errors that were previously raised as `MisconfigurationException`s will now be raised as `ProcessRaisedException` (torch>=1.8) or as `Exception` (torch<1.8)
+ * Removed the `TrainingTypePlugin.pre_dispatch()` method and merged it with `TrainingTypePlugin.setup()` ([#11137](https://github.com/PyTorchLightning/pytorch-lightning/pull/11137))
- Changed profiler to index and display the names of the hooks with a new pattern []. ([#11026](https://github.com/PyTorchLightning/pytorch-lightning/pull/11026))
diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py
index 9066cc540e..ce9d3a3c64 100644
--- a/pytorch_lightning/plugins/training_type/ddp.py
+++ b/pytorch_lightning/plugins/training_type/ddp.py
@@ -151,6 +151,24 @@ class DDPPlugin(ParallelPlugin):
self.setup_distributed()
super().setup_environment()
+ def setup(self, trainer: "pl.Trainer") -> None:
+ super().setup(trainer)
+ # share ddp pids with all processes
+ self._rank_0_has_called_call_children_scripts = self.broadcast(self._rank_0_has_called_call_children_scripts)
+ if self._should_run_deadlock_detection():
+ self._share_information_to_prevent_deadlock()
+
+ # move the model to the correct device
+ self.model_to_device()
+
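+ # convert BatchNorm layers to SyncBatchNorm if requested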
+ if self.sync_batchnorm:
+ self.model = self.configure_sync_batchnorm(self.model)
+
+ # skip wrapping the model if we are not fitting, as no gradients need to be exchanged
+ trainer_fn = self.lightning_module.trainer.state.fn
+ if trainer_fn == TrainerFn.FITTING:
+ self.configure_ddp()
+
def _setup_model(self, model: Module) -> DistributedDataParallel:
"""Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module."""
return DistributedDataParallel(module=model, device_ids=self.determine_ddp_device_ids(), **self._ddp_kwargs)
@@ -341,24 +359,6 @@ class DDPPlugin(ParallelPlugin):
return None
return [self.root_device.index]
- def pre_dispatch(self, trainer: "pl.Trainer") -> None:
- super().pre_dispatch(trainer)
- # share ddp pids to all processes
- self._rank_0_has_called_call_children_scripts = self.broadcast(self._rank_0_has_called_call_children_scripts)
- if self._should_run_deadlock_detection():
- self._share_information_to_prevent_deadlock()
-
- # move the model to the correct device
- self.model_to_device()
-
- if self.sync_batchnorm:
- self.model = self.configure_sync_batchnorm(self.model)
-
- # skip wrapping the model if we are not fitting as no gradients need to be exchanged
- trainer_fn = self.lightning_module.trainer.state.fn
- if trainer_fn == TrainerFn.FITTING:
- self.configure_ddp()
-
def barrier(self, *args, **kwargs) -> None:
if not distributed_available():
return
diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py
index 0468177e4a..5c25b5078c 100644
--- a/pytorch_lightning/plugins/training_type/ddp_spawn.py
+++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -120,6 +120,17 @@ class DDPSpawnPlugin(ParallelPlugin):
os.environ["MASTER_PORT"] = str(self.cluster_environment.main_port)
super().setup(trainer)
+ # move the model to the correct device
+ self.model_to_device()
+
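+ # convert BatchNorm layers to SyncBatchNorm if requested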
+ if self.sync_batchnorm:
+ self.model = self.configure_sync_batchnorm(self.model)
+
+ # skip wrapping the model if we are not fitting, as no gradients need to be exchanged
+ trainer_fn = self.lightning_module.trainer.state.fn
+ if trainer_fn == TrainerFn.FITTING:
+ self.configure_ddp()
+
def _setup_model(self, model: Module) -> DistributedDataParallel:
"""Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module."""
return DistributedDataParallel(module=model, device_ids=self.determine_ddp_device_ids(), **self._ddp_kwargs)
@@ -170,20 +181,6 @@ class DDPSpawnPlugin(ParallelPlugin):
self.cluster_environment, self.torch_distributed_backend, self.global_rank, self.world_size
)
- def pre_dispatch(self, trainer: "pl.Trainer") -> None:
- super().pre_dispatch(trainer)
-
- # move the model to the correct device
- self.model_to_device()
-
- if self.sync_batchnorm:
- self.model = self.configure_sync_batchnorm(self.model)
-
- # skip wrapping the model if we are not fitting as no gradients need to be exchanged
- trainer_fn = self.lightning_module.trainer.state.fn
- if trainer_fn == TrainerFn.FITTING:
- self.configure_ddp()
-
def pre_configure_ddp(self):
# if unset, default `find_unused_parameters` `True`
# Many models require setting this parameter to True, as there are corner cases
diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py
index a2da329562..67343ff780 100644
--- a/pytorch_lightning/plugins/training_type/deepspeed.py
+++ b/pytorch_lightning/plugins/training_type/deepspeed.py
@@ -346,6 +346,14 @@ class DeepSpeedPlugin(DDPPlugin):
self._format_config()
self._config_initialized = True
+ def setup(self, trainer: "pl.Trainer") -> None:
+ self.accelerator.setup(trainer)
+ self.setup_optimizers(trainer)
+ self.setup_precision_plugin()
+ self._move_optimizer_state()
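+ # initialize the DeepSpeed engine and synchronize all processes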
+ self.init_deepspeed()
+ self.barrier()
+
def _init_deepspeed_distributed(self) -> None:
if platform.system() != "Windows":
# do not set env variables on windows, allow deepspeed to control setup
@@ -368,11 +376,6 @@ class DeepSpeedPlugin(DDPPlugin):
def restore_checkpoint_after_pre_dispatch(self) -> bool:
return True
- def pre_dispatch(self, trainer: "pl.Trainer") -> None:
- self._move_optimizer_state()
- self.init_deepspeed()
- self.barrier()
-
def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer]) -> Tuple[Module, List[Optimizer]]:
"""Setup a model and multiple optimizers together.
diff --git a/pytorch_lightning/plugins/training_type/fully_sharded.py b/pytorch_lightning/plugins/training_type/fully_sharded.py
index 2022dcbd33..be2a3671ca 100644
--- a/pytorch_lightning/plugins/training_type/fully_sharded.py
+++ b/pytorch_lightning/plugins/training_type/fully_sharded.py
@@ -129,6 +129,19 @@ class DDPFullyShardedStrategy(DDPPlugin):
)
super().setup_distributed()
+ def setup(self, trainer: "pl.Trainer") -> None:
+ self.accelerator.setup(trainer)
+ self.setup_optimizers(trainer)
+ self.setup_precision_plugin()
+ self._move_optimizer_state()
+
+ if self.sync_batchnorm:
+ self.model = self.configure_sync_batchnorm(self.model)
+
+ self.configure_ddp()
+ self.barrier()
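+ # re-setup the optimizers now that fully sharded has wrapped the lightning module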
+ self.setup_optimizers(trainer)
+
@contextlib.contextmanager
def model_sharded_context(self) -> Generator:
precision = self.precision_plugin.precision
@@ -163,14 +176,6 @@ class DDPFullyShardedStrategy(DDPPlugin):
# setup optimizers after fully sharded has wrapped the lightning module
self.setup_optimizers(self.lightning_module.trainer)
- def pre_dispatch(self, trainer: "pl.Trainer") -> None:
- self._move_optimizer_state()
- if self.sync_batchnorm:
- self.model = self.configure_sync_batchnorm(self.model)
- self.configure_ddp()
- self.barrier()
- self.setup_optimizers(trainer)
-
def model_to_device(self) -> None:
# ensure we update the device type in the lightning module
self.lightning_module.to(self.root_device)
diff --git a/pytorch_lightning/plugins/training_type/horovod.py b/pytorch_lightning/plugins/training_type/horovod.py
index 858d290b20..53fe65a94a 100644
--- a/pytorch_lightning/plugins/training_type/horovod.py
+++ b/pytorch_lightning/plugins/training_type/horovod.py
@@ -79,10 +79,9 @@ class HorovodPlugin(ParallelPlugin):
def setup(self, trainer: "pl.Trainer") -> None:
self.model_to_device()
+
super().setup(trainer)
- def pre_dispatch(self, trainer: "pl.Trainer") -> None:
- super().pre_dispatch(trainer)
self._exit_stack = ExitStack()
self._exit_stack.__enter__()
diff --git a/pytorch_lightning/plugins/training_type/ipu.py b/pytorch_lightning/plugins/training_type/ipu.py
index 9a1ddaf9b3..d8e1097212 100644
--- a/pytorch_lightning/plugins/training_type/ipu.py
+++ b/pytorch_lightning/plugins/training_type/ipu.py
@@ -126,14 +126,6 @@ class IPUPlugin(ParallelPlugin):
super().setup(trainer)
- def setup_optimizers(self, trainer: "pl.Trainer") -> None:
- super().setup_optimizers(trainer)
-
- if len(self.optimizers) > 1:
- raise MisconfigurationException("IPUs currently only support one optimizer.")
-
- def pre_dispatch(self, trainer: "pl.Trainer") -> None:
- super().pre_dispatch(trainer)
model = LightningIPUModule(self.lightning_module, self.precision_plugin.precision)
self.model = model
@@ -164,6 +156,12 @@ class IPUPlugin(ParallelPlugin):
model = poptorch.inferenceModel(model=model, options=self.inference_opts)
self.poptorch_models[RunningStage.PREDICTING] = model
+ def setup_optimizers(self, trainer: "pl.Trainer") -> None:
+ super().setup_optimizers(trainer)
+
+ if len(self.optimizers) > 1:
+ raise MisconfigurationException("IPUs currently only support one optimizer.")
+
@property
def replication_factor(self) -> int:
if not self.lightning_module or not self.poptorch_models:
diff --git a/pytorch_lightning/plugins/training_type/single_tpu.py b/pytorch_lightning/plugins/training_type/single_tpu.py
index 34bb0b01f4..3d23504ea5 100644
--- a/pytorch_lightning/plugins/training_type/single_tpu.py
+++ b/pytorch_lightning/plugins/training_type/single_tpu.py
@@ -64,11 +64,6 @@ class SingleTPUPlugin(SingleDevicePlugin):
super().setup(trainer)
- def model_to_device(self) -> None:
- self.model.to(self.root_device)
-
- def pre_dispatch(self, trainer: "pl.Trainer") -> None:
- super().pre_dispatch(trainer)
if isinstance(self.device, int):
self.device = xm.xla_device(self.device)
@@ -78,6 +73,9 @@ class SingleTPUPlugin(SingleDevicePlugin):
self.tpu_local_core_rank = xm.get_local_ordinal()
self.tpu_global_core_rank = xm.get_ordinal()
+ def model_to_device(self) -> None:
+ self.model.to(self.root_device)
+
def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: _PATH) -> None:
"""Save model/training states as a checkpoint file through state-dump and file-write.
diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py
index 226c1cdca7..bbf545f441 100644
--- a/pytorch_lightning/plugins/training_type/tpu_spawn.py
+++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -120,8 +120,13 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
self.wrapped_model = xmp.MpModelWrapper(LightningDistributedModule(model))
return super().connect(model)
- def pre_dispatch(self, trainer: "pl.Trainer") -> None:
+ def setup(self, trainer: "pl.Trainer") -> None:
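+ # TPU processes are created with the "fork" start method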
+ self.start_method = "fork"
+ self.accelerator.setup(trainer)
+ self.setup_optimizers(trainer)
+ self.setup_precision_plugin()
self._move_optimizer_state()
+
if self.debug:
os.environ["PT_XLA_DEBUG"] = str(1)
diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py
index 0e48aab404..cd45a38712 100644
--- a/pytorch_lightning/plugins/training_type/training_type_plugin.py
+++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -121,6 +121,7 @@ class Strategy(ABC):
self.accelerator.setup(trainer)
self.setup_optimizers(trainer)
self.setup_precision_plugin()
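+ # move the optimizer states to the corresponding device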
+ self._move_optimizer_state()
def setup_precision_plugin(self) -> None:
"""Attaches the precision plugin to the accelerator."""
@@ -490,10 +491,6 @@ class Strategy(ABC):
"""Called in the training loop before anything happens for that batch."""
pass
- def pre_dispatch(self, trainer: "pl.Trainer") -> None:
- """Hook to do something before the training/evaluation/prediction starts."""
- self._move_optimizer_state()
-
def dispatch(self, trainer: "pl.Trainer") -> None:
"""Hook to do something before the training/evaluation/prediction starts."""
self.precision_plugin.dispatch(trainer)
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 246f53d280..dc19aa206a 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -1111,7 +1111,6 @@ class Trainer(
self._restore_modules_and_callbacks(ckpt_path)
self._call_configure_sharded_model() # allow user to setup in model sharded environment
- self.training_type_plugin.setup(self)
# ----------------------------
# INSPECT THE CORE LOOPS
@@ -1145,13 +1144,15 @@ class Trainer(
self.logger_connector.reset_results()
self.logger_connector.reset_metrics()
+ # the strategy will configure the model and move it to the device
+ self.training_type_plugin.setup(self)
+
# hook
if self.state.fn == TrainerFn.FITTING:
self._call_callback_hooks("on_fit_start")
self._call_lightning_module_hook("on_fit_start")
- # plugin will move model to device
- self._pre_dispatch()
+ self._log_hyperparams()
if self.training_type_plugin.restore_checkpoint_after_pre_dispatch:
self._restore_modules_and_callbacks(ckpt_path)
@@ -1183,10 +1184,6 @@ class Trainer(
return results
- def _pre_dispatch(self):
- self.training_type_plugin.pre_dispatch(self)
- self._log_hyperparams()
-
def _log_hyperparams(self) -> None:
# log hyper-parameters
hparams_initial = None
diff --git a/tests/accelerators/test_cpu.py b/tests/accelerators/test_cpu.py
index 2ef234b1ff..de0030fdb9 100644
--- a/tests/accelerators/test_cpu.py
+++ b/tests/accelerators/test_cpu.py
@@ -28,18 +28,18 @@ def test_restore_checkpoint_after_pre_dispatch(tmpdir, restore_after_pre_dispatc
dispatch is called."""
class TestPlugin(SingleDevicePlugin):
- predispatched_called = False
+ setup_called = False
- def pre_dispatch(self, trainer: "pl.Trainer") -> None:
- super().pre_dispatch(trainer)
- self.predispatched_called = True
+ def setup(self, trainer: "pl.Trainer") -> None:
+ super().setup(trainer)
+ self.setup_called = True
@property
def restore_checkpoint_after_pre_dispatch(self) -> bool:
return restore_after_pre_dispatch
def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
- assert self.predispatched_called == restore_after_pre_dispatch
+ assert self.setup_called == restore_after_pre_dispatch
return super().load_checkpoint(checkpoint_path)
model = BoringModel()
@@ -60,5 +60,5 @@ def test_restore_checkpoint_after_pre_dispatch(tmpdir, restore_after_pre_dispatc
trainer = Trainer(default_root_dir=tmpdir, strategy=plugin, fast_dev_run=True)
trainer.fit(model, ckpt_path=checkpoint_path)
for func in (trainer.test, trainer.validate, trainer.predict):
- plugin.predispatched_called = False
+ plugin.setup_called = False
func(model, ckpt_path=checkpoint_path)
diff --git a/tests/accelerators/test_ipu.py b/tests/accelerators/test_ipu.py
index d404db964a..7af34035d9 100644
--- a/tests/accelerators/test_ipu.py
+++ b/tests/accelerators/test_ipu.py
@@ -436,7 +436,7 @@ def test_replication_factor(tmpdir):
plugin.model = model
model.trainer = trainer
trainer.state.fn = TrainerFn.FITTING
- trainer.training_type_plugin.pre_dispatch(trainer)
+ trainer.training_type_plugin.setup(trainer)
trainer.state.stage = RunningStage.TRAINING
assert trainer.training_type_plugin.replication_factor == 8
@@ -450,7 +450,7 @@ def test_replication_factor(tmpdir):
):
trainer.state.fn = fn
trainer.state.stage = stage
- trainer.training_type_plugin.pre_dispatch(trainer)
+ trainer.training_type_plugin.setup(trainer)
assert trainer.training_type_plugin.replication_factor == 7
@@ -585,7 +585,7 @@ def test_poptorch_models_at_different_stages(tmpdir):
trainer.optimizers = model.configure_optimizers()[0]
trainer.state.fn = TrainerFn.FITTING
- trainer.training_type_plugin.pre_dispatch(trainer)
+ trainer.training_type_plugin.setup(trainer)
assert list(trainer.training_type_plugin.poptorch_models) == [RunningStage.TRAINING, RunningStage.VALIDATING]
for fn, stage in (
@@ -595,7 +595,7 @@ def test_poptorch_models_at_different_stages(tmpdir):
):
trainer.state.fn = fn
trainer.state.stage = stage
- trainer.training_type_plugin.pre_dispatch(trainer)
+ trainer.training_type_plugin.setup(trainer)
assert list(trainer.training_type_plugin.poptorch_models) == [stage]
diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py
index 43321ad2ae..5842152278 100644
--- a/tests/models/test_hooks.py
+++ b/tests/models/test_hooks.py
@@ -516,13 +516,9 @@ def _run_trainer_model_hook_system_fit(kwargs, tmpdir, automatic_optimization):
dict(name="setup", kwargs=dict(stage="fit")),
dict(name="configure_sharded_model"),
dict(name="Callback.on_configure_sharded_model", args=(trainer, model)),
- # DeepSpeed skips initializing optimizers here as they are handled via config
- *([dict(name="configure_optimizers")] if kwargs.get("strategy") != "deepspeed" else []),
+ dict(name="configure_optimizers"),
dict(name="Callback.on_fit_start", args=(trainer, model)),
dict(name="on_fit_start"),
- # TODO: explore whether DeepSpeed can have the same flow for optimizers
- # DeepSpeed did not find any optimizer in the config so they are loaded here
- *([dict(name="configure_optimizers")] if kwargs.get("strategy") == "deepspeed" else []),
dict(name="Callback.on_pretrain_routine_start", args=(trainer, model)),
dict(name="on_pretrain_routine_start"),
dict(name="Callback.on_pretrain_routine_end", args=(trainer, model)),
diff --git a/tests/plugins/test_ddp_plugin.py b/tests/plugins/test_ddp_plugin.py
index e99474efd8..58633beb01 100644
--- a/tests/plugins/test_ddp_plugin.py
+++ b/tests/plugins/test_ddp_plugin.py
@@ -110,11 +110,10 @@ def test_ddp_configure_ddp():
# test wrap the model if fitting
trainer.state.fn = TrainerFn.FITTING
trainer.training_type_plugin.connect(model)
- trainer.training_type_plugin.setup_environment()
- trainer.training_type_plugin.setup(trainer)
trainer.lightning_module.trainer = trainer
+ trainer.training_type_plugin.setup_environment()
assert isinstance(trainer.model, LightningModule)
- trainer._pre_dispatch()
+ trainer.training_type_plugin.setup(trainer)
# in DDPPlugin configure_ddp(), model wrapped by DistributedDataParallel
assert isinstance(trainer.model, DistributedDataParallel)
@@ -123,10 +122,10 @@ def test_ddp_configure_ddp():
strategy=ddp_plugin,
)
# test do not wrap the model if trainerFN is not fitting
+ trainer.state.fn = TrainerFn.VALIDATING
trainer.training_type_plugin.connect(model)
+ trainer.lightning_module.trainer = trainer
trainer.training_type_plugin.setup_environment()
trainer.training_type_plugin.setup(trainer)
- trainer.lightning_module.trainer = trainer
- trainer._pre_dispatch()
# in DDPPlugin configure_ddp(), model are still LightningModule
assert isinstance(trainer.model, LightningModule)
diff --git a/tests/trainer/logging_/test_distributed_logging.py b/tests/trainer/logging_/test_distributed_logging.py
index 36c266343b..f62adbddc6 100644
--- a/tests/trainer/logging_/test_distributed_logging.py
+++ b/tests/trainer/logging_/test_distributed_logging.py
@@ -104,8 +104,8 @@ def test_first_logger_call_in_subprocess(tmpdir):
"""
class LoggerCallsObserver(Callback):
- def on_fit_start(self, trainer, pl_module):
- # this hook is executed directly before Trainer.pre_dispatch
+ def setup(self, trainer, pl_module, stage):
+ # this hook is executed after the Strategy has set up the environment
# logger should not write any logs until this point
assert not trainer.logger.method_calls
assert not os.listdir(trainer.logger.save_dir)