From f5c2881b68341fa840886f09ff3b1d3f068783e9 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Mon, 20 Dec 2021 17:41:22 +0100
Subject: [PATCH] 3/n Simplify spawn plugins: Merge `pre_dispatch` and `setup`
 logic (#11137)

---
 CHANGELOG.md                                  |  1 +
 .../plugins/training_type/ddp.py              | 36 +++++++++----------
 .../plugins/training_type/ddp_spawn.py        | 25 ++++++-------
 .../plugins/training_type/deepspeed.py        | 13 ++++---
 .../plugins/training_type/fully_sharded.py    | 21 ++++++-----
 .../plugins/training_type/horovod.py          |  3 +-
 .../plugins/training_type/ipu.py              | 14 ++++----
 .../plugins/training_type/single_tpu.py       |  8 ++---
 .../plugins/training_type/tpu_spawn.py        |  7 +++-
 .../training_type/training_type_plugin.py     |  5 +--
 pytorch_lightning/trainer/trainer.py          | 11 +++---
 tests/accelerators/test_cpu.py                | 12 +++----
 tests/accelerators/test_ipu.py                |  8 ++---
 tests/models/test_hooks.py                    |  6 +---
 tests/plugins/test_ddp_plugin.py              |  9 +++--
 .../logging_/test_distributed_logging.py      |  4 +--
 16 files changed, 89 insertions(+), 94 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 8743068a6f..b5636dab02 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -112,6 +112,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
     * All spawn-based plugins now spawn processes immediately upon calling `Trainer.{fit,validate,test,predict}`
     * The hooks/callbacks `prepare_data`, `setup`, `configure_sharded_model` and `teardown` now run under initialized process group for spawn-based plugins just like their non-spawn counterparts
     * Some configuration errors that were previously raised as `MisconfigurationException`s will now be raised as `ProcessRaisedException` (torch>=1.8) or as `Exception` (torch<1.8)
+    * Removed the `TrainingTypePlugin.pre_dispatch()` method and merged it with `TrainingTypePlugin.setup()` ([#11137](https://github.com/PyTorchLightning/pytorch-lightning/pull/11137))
 
 - Changed profiler to index and display the names of the hooks with a new pattern []. ([#11026](https://github.com/PyTorchLightning/pytorch-lightning/pull/11026))
 
diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py
index 9066cc540e..ce9d3a3c64 100644
--- a/pytorch_lightning/plugins/training_type/ddp.py
+++ b/pytorch_lightning/plugins/training_type/ddp.py
@@ -151,6 +151,24 @@ class DDPPlugin(ParallelPlugin):
         self.setup_distributed()
         super().setup_environment()
 
+    def setup(self, trainer: "pl.Trainer") -> None:
+        super().setup(trainer)
+        # share ddp pids to all processes
+        self._rank_0_has_called_call_children_scripts = self.broadcast(self._rank_0_has_called_call_children_scripts)
+        if self._should_run_deadlock_detection():
+            self._share_information_to_prevent_deadlock()
+
+        # move the model to the correct device
+        self.model_to_device()
+
+        if self.sync_batchnorm:
+            self.model = self.configure_sync_batchnorm(self.model)
+
+        # skip wrapping the model if we are not fitting as no gradients need to be exchanged
+        trainer_fn = self.lightning_module.trainer.state.fn
+        if trainer_fn == TrainerFn.FITTING:
+            self.configure_ddp()
+
     def _setup_model(self, model: Module) -> DistributedDataParallel:
         """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module."""
         return DistributedDataParallel(module=model, device_ids=self.determine_ddp_device_ids(), **self._ddp_kwargs)
@@ -341,24 +359,6 @@ class DDPPlugin(ParallelPlugin):
             return None
         return [self.root_device.index]
 
-    def pre_dispatch(self, trainer: "pl.Trainer") -> None:
-        super().pre_dispatch(trainer)
-        # share ddp pids to all processes
-        self._rank_0_has_called_call_children_scripts = self.broadcast(self._rank_0_has_called_call_children_scripts)
-        if self._should_run_deadlock_detection():
-            self._share_information_to_prevent_deadlock()
-
-        # move the model to the correct device
-        self.model_to_device()
-
-        if self.sync_batchnorm:
-            self.model = self.configure_sync_batchnorm(self.model)
-
-        # skip wrapping the model if we are not fitting as no gradients need to be exchanged
-        trainer_fn = self.lightning_module.trainer.state.fn
-        if trainer_fn == TrainerFn.FITTING:
-            self.configure_ddp()
-
     def barrier(self, *args, **kwargs) -> None:
         if not distributed_available():
             return
diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py
index 0468177e4a..5c25b5078c 100644
--- a/pytorch_lightning/plugins/training_type/ddp_spawn.py
+++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -120,6 +120,17 @@ class DDPSpawnPlugin(ParallelPlugin):
         os.environ["MASTER_PORT"] = str(self.cluster_environment.main_port)
         super().setup(trainer)
 
+        # move the model to the correct device
+        self.model_to_device()
+
+        if self.sync_batchnorm:
+            self.model = self.configure_sync_batchnorm(self.model)
+
+        # skip wrapping the model if we are not fitting as no gradients need to be exchanged
+        trainer_fn = self.lightning_module.trainer.state.fn
+        if trainer_fn == TrainerFn.FITTING:
+            self.configure_ddp()
+
     def _setup_model(self, model: Module) -> DistributedDataParallel:
         """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module."""
         return DistributedDataParallel(module=model, device_ids=self.determine_ddp_device_ids(), **self._ddp_kwargs)
@@ -170,20 +181,6 @@ class DDPSpawnPlugin(ParallelPlugin):
             self.cluster_environment, self.torch_distributed_backend, self.global_rank, self.world_size
         )
 
-    def pre_dispatch(self, trainer: "pl.Trainer") -> None:
-        super().pre_dispatch(trainer)
-
-        # move the model to the correct device
-        self.model_to_device()
-
-        if self.sync_batchnorm:
-            self.model = self.configure_sync_batchnorm(self.model)
-
-        # skip wrapping the model if we are not fitting as no gradients need to be exchanged
-        trainer_fn = self.lightning_module.trainer.state.fn
-        if trainer_fn == TrainerFn.FITTING:
-            self.configure_ddp()
-
     def pre_configure_ddp(self):
         # if unset, default `find_unused_parameters` `True`
         # Many models require setting this parameter to True, as there are corner cases
diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py
index a2da329562..67343ff780 100644
--- a/pytorch_lightning/plugins/training_type/deepspeed.py
+++ b/pytorch_lightning/plugins/training_type/deepspeed.py
@@ -346,6 +346,14 @@ class DeepSpeedPlugin(DDPPlugin):
             self._format_config()
             self._config_initialized = True
 
+    def setup(self, trainer: "pl.Trainer") -> None:
+        self.accelerator.setup(trainer)
+        self.setup_optimizers(trainer)
+        self.setup_precision_plugin()
+        self._move_optimizer_state()
+        self.init_deepspeed()
+        self.barrier()
+
     def _init_deepspeed_distributed(self) -> None:
         if platform.system() != "Windows":
             # do not set env variables on windows, allow deepspeed to control setup
@@ -368,11 +376,6 @@ class DeepSpeedPlugin(DDPPlugin):
     def restore_checkpoint_after_pre_dispatch(self) -> bool:
         return True
 
-    def pre_dispatch(self, trainer: "pl.Trainer") -> None:
-        self._move_optimizer_state()
-        self.init_deepspeed()
-        self.barrier()
-
     def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer]) -> Tuple[Module, List[Optimizer]]:
         """Setup a model and multiple optimizers together.
 
diff --git a/pytorch_lightning/plugins/training_type/fully_sharded.py b/pytorch_lightning/plugins/training_type/fully_sharded.py
index 2022dcbd33..be2a3671ca 100644
--- a/pytorch_lightning/plugins/training_type/fully_sharded.py
+++ b/pytorch_lightning/plugins/training_type/fully_sharded.py
@@ -129,6 +129,19 @@ class DDPFullyShardedStrategy(DDPPlugin):
             )
         super().setup_distributed()
 
+    def setup(self, trainer: "pl.Trainer") -> None:
+        self.accelerator.setup(trainer)
+        self.setup_optimizers(trainer)
+        self.setup_precision_plugin()
+        self._move_optimizer_state()
+
+        if self.sync_batchnorm:
+            self.model = self.configure_sync_batchnorm(self.model)
+
+        self.configure_ddp()
+        self.barrier()
+        self.setup_optimizers(trainer)
+
     @contextlib.contextmanager
     def model_sharded_context(self) -> Generator:
         precision = self.precision_plugin.precision
@@ -163,14 +176,6 @@ class DDPFullyShardedStrategy(DDPPlugin):
         # setup optimizers after fully sharded has wrapped the lightning module
         self.setup_optimizers(self.lightning_module.trainer)
 
-    def pre_dispatch(self, trainer: "pl.Trainer") -> None:
-        self._move_optimizer_state()
-        if self.sync_batchnorm:
-            self.model = self.configure_sync_batchnorm(self.model)
-        self.configure_ddp()
-        self.barrier()
-        self.setup_optimizers(trainer)
-
     def model_to_device(self) -> None:
         # ensure we update the device type in the lightning module
         self.lightning_module.to(self.root_device)
diff --git a/pytorch_lightning/plugins/training_type/horovod.py b/pytorch_lightning/plugins/training_type/horovod.py
index 858d290b20..53fe65a94a 100644
--- a/pytorch_lightning/plugins/training_type/horovod.py
+++ b/pytorch_lightning/plugins/training_type/horovod.py
@@ -79,10 +79,9 @@ class HorovodPlugin(ParallelPlugin):
 
     def setup(self, trainer: "pl.Trainer") -> None:
         self.model_to_device()
+        super().setup(trainer)
 
-    def pre_dispatch(self, trainer: "pl.Trainer") -> None:
-        super().pre_dispatch(trainer)
 
         self._exit_stack = ExitStack()
         self._exit_stack.__enter__()
 
diff --git a/pytorch_lightning/plugins/training_type/ipu.py b/pytorch_lightning/plugins/training_type/ipu.py
index 9a1ddaf9b3..d8e1097212 100644
--- a/pytorch_lightning/plugins/training_type/ipu.py
+++ b/pytorch_lightning/plugins/training_type/ipu.py
@@ -126,14 +126,6 @@ class IPUPlugin(ParallelPlugin):
 
         super().setup(trainer)
 
-    def setup_optimizers(self, trainer: "pl.Trainer") -> None:
-        super().setup_optimizers(trainer)
-
-        if len(self.optimizers) > 1:
-            raise MisconfigurationException("IPUs currently only support one optimizer.")
-
-    def pre_dispatch(self, trainer: "pl.Trainer") -> None:
-        super().pre_dispatch(trainer)
         model = LightningIPUModule(self.lightning_module, self.precision_plugin.precision)
         self.model = model
 
@@ -164,6 +156,12 @@ class IPUPlugin(ParallelPlugin):
             model = poptorch.inferenceModel(model=model, options=self.inference_opts)
             self.poptorch_models[RunningStage.PREDICTING] = model
 
+    def setup_optimizers(self, trainer: "pl.Trainer") -> None:
+        super().setup_optimizers(trainer)
+
+        if len(self.optimizers) > 1:
+            raise MisconfigurationException("IPUs currently only support one optimizer.")
+
     @property
     def replication_factor(self) -> int:
         if not self.lightning_module or not self.poptorch_models:
diff --git a/pytorch_lightning/plugins/training_type/single_tpu.py b/pytorch_lightning/plugins/training_type/single_tpu.py
index 34bb0b01f4..3d23504ea5 100644
--- a/pytorch_lightning/plugins/training_type/single_tpu.py
+++ b/pytorch_lightning/plugins/training_type/single_tpu.py
@@ -64,11 +64,6 @@ class SingleTPUPlugin(SingleDevicePlugin):
 
         super().setup(trainer)
 
-    def model_to_device(self) -> None:
-        self.model.to(self.root_device)
-
-    def pre_dispatch(self, trainer: "pl.Trainer") -> None:
-        super().pre_dispatch(trainer)
         if isinstance(self.device, int):
             self.device = xm.xla_device(self.device)
 
@@ -78,6 +73,9 @@ class SingleTPUPlugin(SingleDevicePlugin):
         self.tpu_local_core_rank = xm.get_local_ordinal()
         self.tpu_global_core_rank = xm.get_ordinal()
 
+    def model_to_device(self) -> None:
+        self.model.to(self.root_device)
+
     def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: _PATH) -> None:
         """Save model/training states as a checkpoint file through state-dump and file-write.
 
diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py
index 226c1cdca7..bbf545f441 100644
--- a/pytorch_lightning/plugins/training_type/tpu_spawn.py
+++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -120,8 +120,13 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
         self.wrapped_model = xmp.MpModelWrapper(LightningDistributedModule(model))
         return super().connect(model)
 
-    def pre_dispatch(self, trainer: "pl.Trainer") -> None:
+    def setup(self, trainer: "pl.Trainer") -> None:
+        self.start_method = "fork"
+        self.accelerator.setup(trainer)
+        self.setup_optimizers(trainer)
+        self.setup_precision_plugin()
         self._move_optimizer_state()
+
         if self.debug:
             os.environ["PT_XLA_DEBUG"] = str(1)
 
diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py
index 0e48aab404..cd45a38712 100644
--- a/pytorch_lightning/plugins/training_type/training_type_plugin.py
+++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -121,6 +121,7 @@ class Strategy(ABC):
         self.accelerator.setup(trainer)
         self.setup_optimizers(trainer)
         self.setup_precision_plugin()
+        self._move_optimizer_state()
 
     def setup_precision_plugin(self) -> None:
         """Attaches the precision plugin to the accelerator."""
@@ -490,10 +491,6 @@ class Strategy(ABC):
         """Called in the training loop before anything happens for that batch."""
         pass
 
-    def pre_dispatch(self, trainer: "pl.Trainer") -> None:
-        """Hook to do something before the training/evaluation/prediction starts."""
-        self._move_optimizer_state()
-
     def dispatch(self, trainer: "pl.Trainer") -> None:
         """Hook to do something before the training/evaluation/prediction starts."""
         self.precision_plugin.dispatch(trainer)
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 246f53d280..dc19aa206a 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -1111,7 +1111,6 @@ class Trainer(
             self._restore_modules_and_callbacks(ckpt_path)
 
         self._call_configure_sharded_model()  # allow user to setup in model sharded environment
-        self.training_type_plugin.setup(self)
 
         # ----------------------------
         # INSPECT THE CORE LOOPS
@@ -1145,13 +1144,15 @@ class Trainer(
         self.logger_connector.reset_results()
         self.logger_connector.reset_metrics()
 
+        # strategy will configure model and move it to the device
+        self.training_type_plugin.setup(self)
+
         # hook
         if self.state.fn == TrainerFn.FITTING:
             self._call_callback_hooks("on_fit_start")
             self._call_lightning_module_hook("on_fit_start")
 
-        # plugin will move model to device
-        self._pre_dispatch()
+        self._log_hyperparams()
 
         if self.training_type_plugin.restore_checkpoint_after_pre_dispatch:
             self._restore_modules_and_callbacks(ckpt_path)
@@ -1183,10 +1184,6 @@
 
         return results
 
-    def _pre_dispatch(self):
-        self.training_type_plugin.pre_dispatch(self)
-        self._log_hyperparams()
-
     def _log_hyperparams(self) -> None:
         # log hyper-parameters
         hparams_initial = None
diff --git a/tests/accelerators/test_cpu.py b/tests/accelerators/test_cpu.py
index 2ef234b1ff..de0030fdb9 100644
--- a/tests/accelerators/test_cpu.py
+++ b/tests/accelerators/test_cpu.py
@@ -28,18 +28,18 @@ def test_restore_checkpoint_after_pre_dispatch(tmpdir, restore_after_pre_dispatc
     dispatch is called."""
 
    class TestPlugin(SingleDevicePlugin):
-        predispatched_called = False
+        setup_called = False
 
-        def pre_dispatch(self, trainer: "pl.Trainer") -> None:
-            super().pre_dispatch(trainer)
-            self.predispatched_called = True
+        def setup(self, trainer: "pl.Trainer") -> None:
+            super().setup(trainer)
+            self.setup_called = True
 
         @property
         def restore_checkpoint_after_pre_dispatch(self) -> bool:
             return restore_after_pre_dispatch
 
         def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
-            assert self.predispatched_called == restore_after_pre_dispatch
+            assert self.setup_called == restore_after_pre_dispatch
             return super().load_checkpoint(checkpoint_path)
 
     model = BoringModel()
@@ -60,5 +60,5 @@ def test_restore_checkpoint_after_pre_dispatc
     trainer = Trainer(default_root_dir=tmpdir, strategy=plugin, fast_dev_run=True)
     trainer.fit(model, ckpt_path=checkpoint_path)
     for func in (trainer.test, trainer.validate, trainer.predict):
-        plugin.predispatched_called = False
+        plugin.setup_called = False
         func(model, ckpt_path=checkpoint_path)
diff --git a/tests/accelerators/test_ipu.py b/tests/accelerators/test_ipu.py
index d404db964a..7af34035d9 100644
--- a/tests/accelerators/test_ipu.py
+++ b/tests/accelerators/test_ipu.py
@@ -436,7 +436,7 @@ def test_replication_factor(tmpdir):
     plugin.model = model
     model.trainer = trainer
     trainer.state.fn = TrainerFn.FITTING
-    trainer.training_type_plugin.pre_dispatch(trainer)
+    trainer.training_type_plugin.setup(trainer)
     trainer.state.stage = RunningStage.TRAINING
     assert trainer.training_type_plugin.replication_factor == 8
 
@@ -450,7 +450,7 @@ def test_replication_factor(tmpdir):
     ):
         trainer.state.fn = fn
         trainer.state.stage = stage
-        trainer.training_type_plugin.pre_dispatch(trainer)
+        trainer.training_type_plugin.setup(trainer)
         assert trainer.training_type_plugin.replication_factor == 7
 
 
@@ -585,7 +585,7 @@ def test_poptorch_models_at_different_stages(tmpdir):
     trainer.optimizers = model.configure_optimizers()[0]
 
     trainer.state.fn = TrainerFn.FITTING
-    trainer.training_type_plugin.pre_dispatch(trainer)
+    trainer.training_type_plugin.setup(trainer)
     assert list(trainer.training_type_plugin.poptorch_models) == [RunningStage.TRAINING, RunningStage.VALIDATING]
 
     for fn, stage in (
@@ -595,7 +595,7 @@ def test_poptorch_models_at_different_stages(tmpdir):
     ):
         trainer.state.fn = fn
         trainer.state.stage = stage
-        trainer.training_type_plugin.pre_dispatch(trainer)
+        trainer.training_type_plugin.setup(trainer)
         assert list(trainer.training_type_plugin.poptorch_models) == [stage]
 
 
diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py
index 43321ad2ae..5842152278 100644
--- a/tests/models/test_hooks.py
+++ b/tests/models/test_hooks.py
@@ -516,13 +516,9 @@ def _run_trainer_model_hook_system_fit(kwargs, tmpdir, automatic_optimization):
         dict(name="setup", kwargs=dict(stage="fit")),
         dict(name="configure_sharded_model"),
         dict(name="Callback.on_configure_sharded_model", args=(trainer, model)),
-        # DeepSpeed skips initializing optimizers here as they are handled via config
-        *([dict(name="configure_optimizers")] if kwargs.get("strategy") != "deepspeed" else []),
+        dict(name="configure_optimizers"),
         dict(name="Callback.on_fit_start", args=(trainer, model)),
         dict(name="on_fit_start"),
-        # TODO: explore whether DeepSpeed can have the same flow for optimizers
-        # DeepSpeed did not find any optimizer in the config so they are loaded here
-        *([dict(name="configure_optimizers")] if kwargs.get("strategy") == "deepspeed" else []),
         dict(name="Callback.on_pretrain_routine_start", args=(trainer, model)),
         dict(name="on_pretrain_routine_start"),
         dict(name="Callback.on_pretrain_routine_end", args=(trainer, model)),
diff --git a/tests/plugins/test_ddp_plugin.py b/tests/plugins/test_ddp_plugin.py
index e99474efd8..58633beb01 100644
--- a/tests/plugins/test_ddp_plugin.py
+++ b/tests/plugins/test_ddp_plugin.py
@@ -110,11 +110,10 @@ def test_ddp_configure_ddp():
     # test wrap the model if fitting
     trainer.state.fn = TrainerFn.FITTING
     trainer.training_type_plugin.connect(model)
-    trainer.training_type_plugin.setup_environment()
-    trainer.training_type_plugin.setup(trainer)
     trainer.lightning_module.trainer = trainer
+    trainer.training_type_plugin.setup_environment()
     assert isinstance(trainer.model, LightningModule)
-    trainer._pre_dispatch()
+    trainer.training_type_plugin.setup(trainer)
     # in DDPPlugin configure_ddp(), model wrapped by DistributedDataParallel
     assert isinstance(trainer.model, DistributedDataParallel)
 
@@ -123,10 +122,10 @@
         strategy=ddp_plugin,
     )
     # test do not wrap the model if trainerFN is not fitting
+    trainer.state.fn = TrainerFn.VALIDATING
     trainer.training_type_plugin.connect(model)
+    trainer.lightning_module.trainer = trainer
     trainer.training_type_plugin.setup_environment()
     trainer.training_type_plugin.setup(trainer)
-    trainer.lightning_module.trainer = trainer
-    trainer._pre_dispatch()
     # in DDPPlugin configure_ddp(), model are still LightningModule
     assert isinstance(trainer.model, LightningModule)
diff --git a/tests/trainer/logging_/test_distributed_logging.py b/tests/trainer/logging_/test_distributed_logging.py
index 36c266343b..f62adbddc6 100644
--- a/tests/trainer/logging_/test_distributed_logging.py
+++ b/tests/trainer/logging_/test_distributed_logging.py
@@ -104,8 +104,8 @@ def test_first_logger_call_in_subprocess(tmpdir):
     """
 
     class LoggerCallsObserver(Callback):
-        def on_fit_start(self, trainer, pl_module):
-            # this hook is executed directly before Trainer.pre_dispatch
+        def setup(self, trainer, pl_module, stage):
+            # this hook is executed after Strategy has setup the environment
+            # logger should not write any logs until this point
             assert not trainer.logger.method_calls
             assert not os.listdir(trainer.logger.save_dir)
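
Note (not part of the patch): for downstream code that subclassed a training type plugin and overrode `pre_dispatch()`, the change above means that logic now belongs in `setup()`. Below is a minimal migration sketch mirroring the pattern used by the built-in plugins and by `TestPlugin` in tests/accelerators/test_cpu.py; `MyPlugin` is a hypothetical user-defined plugin, and the `SingleDevicePlugin` constructor taking a `torch.device` is assumed from the plugin API of this release line.

    import torch

    import pytorch_lightning as pl
    from pytorch_lightning.plugins import SingleDevicePlugin


    class MyPlugin(SingleDevicePlugin):  # hypothetical custom plugin
        # Before this patch, one-time work such as device placement could live in
        # `pre_dispatch(trainer)`, which the Trainer called after `setup(trainer)`:
        #
        #     def pre_dispatch(self, trainer: "pl.Trainer") -> None:
        #         super().pre_dispatch(trainer)
        #         self.model_to_device()
        #
        # After this patch the hook no longer exists, so the same work moves into
        # `setup`, after calling `super().setup(trainer)` (which now also moves the
        # optimizer state, as `Strategy.setup` does in the diff above).
        def setup(self, trainer: "pl.Trainer") -> None:
            super().setup(trainer)
            self.model_to_device()


    # Usage mirrors the updated test: pass the plugin instance as the Trainer strategy.
    trainer = pl.Trainer(strategy=MyPlugin(torch.device("cpu")), fast_dev_run=True)

One ordering detail visible in the trainer.py hunks: `training_type_plugin.setup()` now runs just before the `on_fit_start` hooks, whereas the old `pre_dispatch()` ran after them, so migrated logic executes slightly earlier in the fit loop.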