3/n Simplify spawn plugins: Merge `pre_dispatch` and `setup` logic (#11137)

Adrian Wälchli 2021-12-20 17:41:22 +01:00 committed by GitHub
parent 2e47e2f4ae
commit f5c2881b68
16 changed files with 89 additions and 94 deletions

@@ -112,6 +112,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
     * All spawn-based plugins now spawn processes immediately upon calling `Trainer.{fit,validate,test,predict}`
     * The hooks/callbacks `prepare_data`, `setup`, `configure_sharded_model` and `teardown` now run under initialized process group for spawn-based plugins just like their non-spawn counterparts
     * Some configuration errors that were previously raised as `MisconfigurationException`s will now be raised as `ProcessRaisedException` (torch>=1.8) or as `Exception` (torch<1.8)
+    * Removed the `TrainingTypePlugin.pre_dispatch()` method and merged it with `TrainingTypePlugin.setup()` ([#11137](https://github.com/PyTorchLightning/pytorch-lightning/pull/11137))
 - Changed profiler to index and display the names of the hooks with a new pattern [<base class>]<class>.<hook name> ([#11026](https://github.com/PyTorchLightning/pytorch-lightning/pull/11026))
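The last entry above is the core of this commit: `pre_dispatch` no longer exists as a separate strategy hook, so work that a custom plugin used to defer to `pre_dispatch` now belongs in `setup`, which runs with the process group already initialized. A minimal migration sketch for a user-defined plugin (the `MyPlugin` name and its device-placement body are illustrative, not taken from this commit):

```python
import pytorch_lightning as pl
from pytorch_lightning.plugins import SingleDevicePlugin


class MyPlugin(SingleDevicePlugin):
    # Before this commit (hypothetical user code):
    #
    #     def pre_dispatch(self, trainer: "pl.Trainer") -> None:
    #         super().pre_dispatch(trainer)  # base hook moved optimizer state to the device
    #         self.model_to_device()
    #
    # After this commit the same logic lives in `setup`; calling `super().setup`
    # keeps the base behaviour (accelerator, optimizer and precision setup plus
    # the optimizer-state move that `pre_dispatch` used to perform).
    def setup(self, trainer: "pl.Trainer") -> None:
        super().setup(trainer)
        self.model_to_device()
```

The built-in plugins in the hunks below follow exactly this pattern: each former `pre_dispatch` body is appended to the corresponding `setup`.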

@@ -151,6 +151,24 @@ class DDPPlugin(ParallelPlugin):
         self.setup_distributed()
         super().setup_environment()

+    def setup(self, trainer: "pl.Trainer") -> None:
+        super().setup(trainer)
+        # share ddp pids to all processes
+        self._rank_0_has_called_call_children_scripts = self.broadcast(self._rank_0_has_called_call_children_scripts)
+        if self._should_run_deadlock_detection():
+            self._share_information_to_prevent_deadlock()
+
+        # move the model to the correct device
+        self.model_to_device()
+
+        if self.sync_batchnorm:
+            self.model = self.configure_sync_batchnorm(self.model)
+
+        # skip wrapping the model if we are not fitting as no gradients need to be exchanged
+        trainer_fn = self.lightning_module.trainer.state.fn
+        if trainer_fn == TrainerFn.FITTING:
+            self.configure_ddp()
+
     def _setup_model(self, model: Module) -> DistributedDataParallel:
         """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module."""
         return DistributedDataParallel(module=model, device_ids=self.determine_ddp_device_ids(), **self._ddp_kwargs)

@@ -341,24 +359,6 @@
             return None
         return [self.root_device.index]

-    def pre_dispatch(self, trainer: "pl.Trainer") -> None:
-        super().pre_dispatch(trainer)
-        # share ddp pids to all processes
-        self._rank_0_has_called_call_children_scripts = self.broadcast(self._rank_0_has_called_call_children_scripts)
-        if self._should_run_deadlock_detection():
-            self._share_information_to_prevent_deadlock()
-
-        # move the model to the correct device
-        self.model_to_device()
-
-        if self.sync_batchnorm:
-            self.model = self.configure_sync_batchnorm(self.model)
-
-        # skip wrapping the model if we are not fitting as no gradients need to be exchanged
-        trainer_fn = self.lightning_module.trainer.state.fn
-        if trainer_fn == TrainerFn.FITTING:
-            self.configure_ddp()
-
     def barrier(self, *args, **kwargs) -> None:
         if not distributed_available():
             return

@@ -120,6 +120,17 @@ class DDPSpawnPlugin(ParallelPlugin):
         os.environ["MASTER_PORT"] = str(self.cluster_environment.main_port)
         super().setup(trainer)

+        # move the model to the correct device
+        self.model_to_device()
+
+        if self.sync_batchnorm:
+            self.model = self.configure_sync_batchnorm(self.model)
+
+        # skip wrapping the model if we are not fitting as no gradients need to be exchanged
+        trainer_fn = self.lightning_module.trainer.state.fn
+        if trainer_fn == TrainerFn.FITTING:
+            self.configure_ddp()
+
     def _setup_model(self, model: Module) -> DistributedDataParallel:
         """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module."""
         return DistributedDataParallel(module=model, device_ids=self.determine_ddp_device_ids(), **self._ddp_kwargs)

@@ -170,20 +181,6 @@
             self.cluster_environment, self.torch_distributed_backend, self.global_rank, self.world_size
         )

-    def pre_dispatch(self, trainer: "pl.Trainer") -> None:
-        super().pre_dispatch(trainer)
-        # move the model to the correct device
-        self.model_to_device()
-
-        if self.sync_batchnorm:
-            self.model = self.configure_sync_batchnorm(self.model)
-
-        # skip wrapping the model if we are not fitting as no gradients need to be exchanged
-        trainer_fn = self.lightning_module.trainer.state.fn
-        if trainer_fn == TrainerFn.FITTING:
-            self.configure_ddp()
-
     def pre_configure_ddp(self):
         # if unset, default `find_unused_parameters` `True`
         # Many models require setting this parameter to True, as there are corner cases
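Because spawn-based plugins now launch their workers up front and run `setup` inside each worker with the process group initialized (see the changelog entries above), user-level hooks can rely on collective information being available there. A small usage-level illustration that is not part of this commit; the callback name is made up:

```python
import torch.distributed as dist
import pytorch_lightning as pl


class RankReporter(pl.Callback):
    """Hypothetical callback: with a spawn-based plugin such as ``ddp_spawn``,
    its ``setup`` hook now runs inside each spawned worker, after the process
    group has been initialized."""

    def setup(self, trainer, pl_module, stage=None):
        if dist.is_available() and dist.is_initialized():
            # Rank/world-size queries work here for spawn-based plugins too,
            # matching the behaviour of their non-spawn counterparts.
            print(f"setup on rank {dist.get_rank()} of {dist.get_world_size()} (stage={stage})")
```

Attached via `Trainer(callbacks=[RankReporter()], strategy="ddp_spawn", ...)`: before this series of spawn-plugin changes the hook ran in the main process before workers were spawned, so the guard stayed false; now it reports one line per worker.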

@@ -346,6 +346,14 @@ class DeepSpeedPlugin(DDPPlugin):
             self._format_config()
             self._config_initialized = True

+    def setup(self, trainer: "pl.Trainer") -> None:
+        self.accelerator.setup(trainer)
+        self.setup_optimizers(trainer)
+        self.setup_precision_plugin()
+        self._move_optimizer_state()
+        self.init_deepspeed()
+        self.barrier()
+
     def _init_deepspeed_distributed(self) -> None:
         if platform.system() != "Windows":
             # do not set env variables on windows, allow deepspeed to control setup

@@ -368,11 +376,6 @@
     def restore_checkpoint_after_pre_dispatch(self) -> bool:
         return True

-    def pre_dispatch(self, trainer: "pl.Trainer") -> None:
-        self._move_optimizer_state()
-        self.init_deepspeed()
-        self.barrier()
-
     def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer]) -> Tuple[Module, List[Optimizer]]:
         """Setup a model and multiple optimizers together.

@@ -129,6 +129,19 @@ class DDPFullyShardedStrategy(DDPPlugin):
             )
         super().setup_distributed()

+    def setup(self, trainer: "pl.Trainer") -> None:
+        self.accelerator.setup(trainer)
+        self.setup_optimizers(trainer)
+        self.setup_precision_plugin()
+        self._move_optimizer_state()
+
+        if self.sync_batchnorm:
+            self.model = self.configure_sync_batchnorm(self.model)
+
+        self.configure_ddp()
+        self.barrier()
+        self.setup_optimizers(trainer)
+
     @contextlib.contextmanager
     def model_sharded_context(self) -> Generator:
         precision = self.precision_plugin.precision

@@ -163,14 +176,6 @@
         # setup optimizers after fully sharded has wrapped the lightning module
         self.setup_optimizers(self.lightning_module.trainer)

-    def pre_dispatch(self, trainer: "pl.Trainer") -> None:
-        self._move_optimizer_state()
-        if self.sync_batchnorm:
-            self.model = self.configure_sync_batchnorm(self.model)
-        self.configure_ddp()
-        self.barrier()
-        self.setup_optimizers(trainer)
-
     def model_to_device(self) -> None:
         # ensure we update the device type in the lightning module
         self.lightning_module.to(self.root_device)

@@ -79,10 +79,9 @@ class HorovodPlugin(ParallelPlugin):
     def setup(self, trainer: "pl.Trainer") -> None:
         self.model_to_device()
+        super().setup(trainer)

-    def pre_dispatch(self, trainer: "pl.Trainer") -> None:
-        super().pre_dispatch(trainer)
         self._exit_stack = ExitStack()
         self._exit_stack.__enter__()

@@ -126,14 +126,6 @@ class IPUPlugin(ParallelPlugin):
         super().setup(trainer)

-    def setup_optimizers(self, trainer: "pl.Trainer") -> None:
-        super().setup_optimizers(trainer)
-
-        if len(self.optimizers) > 1:
-            raise MisconfigurationException("IPUs currently only support one optimizer.")
-
-    def pre_dispatch(self, trainer: "pl.Trainer") -> None:
-        super().pre_dispatch(trainer)
         model = LightningIPUModule(self.lightning_module, self.precision_plugin.precision)
         self.model = model

@@ -164,6 +156,12 @@
             model = poptorch.inferenceModel(model=model, options=self.inference_opts)
             self.poptorch_models[RunningStage.PREDICTING] = model

+    def setup_optimizers(self, trainer: "pl.Trainer") -> None:
+        super().setup_optimizers(trainer)
+
+        if len(self.optimizers) > 1:
+            raise MisconfigurationException("IPUs currently only support one optimizer.")
+
     @property
     def replication_factor(self) -> int:
         if not self.lightning_module or not self.poptorch_models:

@@ -64,11 +64,6 @@ class SingleTPUPlugin(SingleDevicePlugin):
         super().setup(trainer)

-    def model_to_device(self) -> None:
-        self.model.to(self.root_device)
-
-    def pre_dispatch(self, trainer: "pl.Trainer") -> None:
-        super().pre_dispatch(trainer)
         if isinstance(self.device, int):
             self.device = xm.xla_device(self.device)

@@ -78,6 +73,9 @@
         self.tpu_local_core_rank = xm.get_local_ordinal()
         self.tpu_global_core_rank = xm.get_ordinal()

+    def model_to_device(self) -> None:
+        self.model.to(self.root_device)
+
     def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: _PATH) -> None:
         """Save model/training states as a checkpoint file through state-dump and file-write.

@@ -120,8 +120,13 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
         self.wrapped_model = xmp.MpModelWrapper(LightningDistributedModule(model))
         return super().connect(model)

-    def pre_dispatch(self, trainer: "pl.Trainer") -> None:
+    def setup(self, trainer: "pl.Trainer") -> None:
         self.start_method = "fork"
+        self.accelerator.setup(trainer)
+        self.setup_optimizers(trainer)
+        self.setup_precision_plugin()
+        self._move_optimizer_state()
+
         if self.debug:
             os.environ["PT_XLA_DEBUG"] = str(1)

@@ -121,6 +121,7 @@ class Strategy(ABC):
         self.accelerator.setup(trainer)
         self.setup_optimizers(trainer)
         self.setup_precision_plugin()
+        self._move_optimizer_state()

     def setup_precision_plugin(self) -> None:
         """Attaches the precision plugin to the accelerator."""

@@ -490,10 +491,6 @@
         """Called in the training loop before anything happens for that batch."""
         pass

-    def pre_dispatch(self, trainer: "pl.Trainer") -> None:
-        """Hook to do something before the training/evaluation/prediction starts."""
-        self._move_optimizer_state()
-
     def dispatch(self, trainer: "pl.Trainer") -> None:
         """Hook to do something before the training/evaluation/prediction starts."""
         self.precision_plugin.dispatch(trainer)

@@ -1111,7 +1111,6 @@ class Trainer(
             self._restore_modules_and_callbacks(ckpt_path)

         self._call_configure_sharded_model()  # allow user to setup in model sharded environment
-        self.training_type_plugin.setup(self)

         # ----------------------------
         # INSPECT THE CORE LOOPS

@@ -1145,13 +1144,15 @@
         self.logger_connector.reset_results()
         self.logger_connector.reset_metrics()

+        # strategy will configure model and move it to the device
+        self.training_type_plugin.setup(self)
+
         # hook
         if self.state.fn == TrainerFn.FITTING:
             self._call_callback_hooks("on_fit_start")
             self._call_lightning_module_hook("on_fit_start")

-        # plugin will move model to device
-        self._pre_dispatch()
+        self._log_hyperparams()

         if self.training_type_plugin.restore_checkpoint_after_pre_dispatch:
             self._restore_modules_and_callbacks(ckpt_path)

@@ -1183,10 +1184,6 @@
         return results

-    def _pre_dispatch(self):
-        self.training_type_plugin.pre_dispatch(self)
-        self._log_hyperparams()
-
     def _log_hyperparams(self) -> None:
         # log hyper-parameters
         hparams_initial = None
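The net effect of the three Trainer hunks above is an ordering change: the strategy's `setup` now runs before the `on_fit_start` hooks, and hyperparameter logging follows directly instead of happening inside the removed `_pre_dispatch`. A hedged way to observe the new order from user code (illustrative callback name; the assertion assumes a single-device or DDP run in which `setup` moves the model to its root device):

```python
import pytorch_lightning as pl


class DeviceAlreadySetUp(pl.Callback):
    def on_fit_start(self, trainer, pl_module):
        # With this commit, the training type plugin's `setup()` has already run
        # by the time `on_fit_start` fires, so the model sits on its target device.
        assert pl_module.device == trainer.training_type_plugin.root_device
```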

@@ -28,18 +28,18 @@ def test_restore_checkpoint_after_pre_dispatch(tmpdir, restore_after_pre_dispatch):
     dispatch is called."""

     class TestPlugin(SingleDevicePlugin):
-        predispatched_called = False
+        setup_called = False

-        def pre_dispatch(self, trainer: "pl.Trainer") -> None:
-            super().pre_dispatch(trainer)
-            self.predispatched_called = True
+        def setup(self, trainer: "pl.Trainer") -> None:
+            super().setup(trainer)
+            self.setup_called = True

         @property
         def restore_checkpoint_after_pre_dispatch(self) -> bool:
             return restore_after_pre_dispatch

         def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
-            assert self.predispatched_called == restore_after_pre_dispatch
+            assert self.setup_called == restore_after_pre_dispatch
             return super().load_checkpoint(checkpoint_path)

     model = BoringModel()

@@ -60,5 +60,5 @@ def test_restore_checkpoint_after_pre_dispatch(tmpdir, restore_after_pre_dispatch):
     trainer = Trainer(default_root_dir=tmpdir, strategy=plugin, fast_dev_run=True)
     trainer.fit(model, ckpt_path=checkpoint_path)
     for func in (trainer.test, trainer.validate, trainer.predict):
-        plugin.predispatched_called = False
+        plugin.setup_called = False
         func(model, ckpt_path=checkpoint_path)

@@ -436,7 +436,7 @@ def test_replication_factor(tmpdir):
     plugin.model = model
     model.trainer = trainer
     trainer.state.fn = TrainerFn.FITTING
-    trainer.training_type_plugin.pre_dispatch(trainer)
+    trainer.training_type_plugin.setup(trainer)

     trainer.state.stage = RunningStage.TRAINING
     assert trainer.training_type_plugin.replication_factor == 8

@@ -450,7 +450,7 @@
     ):
         trainer.state.fn = fn
         trainer.state.stage = stage
-        trainer.training_type_plugin.pre_dispatch(trainer)
+        trainer.training_type_plugin.setup(trainer)

         assert trainer.training_type_plugin.replication_factor == 7

@@ -585,7 +585,7 @@ def test_poptorch_models_at_different_stages(tmpdir):
     trainer.optimizers = model.configure_optimizers()[0]
     trainer.state.fn = TrainerFn.FITTING
-    trainer.training_type_plugin.pre_dispatch(trainer)
+    trainer.training_type_plugin.setup(trainer)

     assert list(trainer.training_type_plugin.poptorch_models) == [RunningStage.TRAINING, RunningStage.VALIDATING]

     for fn, stage in (

@@ -595,7 +595,7 @@
     ):
         trainer.state.fn = fn
         trainer.state.stage = stage
-        trainer.training_type_plugin.pre_dispatch(trainer)
+        trainer.training_type_plugin.setup(trainer)

         assert list(trainer.training_type_plugin.poptorch_models) == [stage]

@@ -516,13 +516,9 @@ def _run_trainer_model_hook_system_fit(kwargs, tmpdir, automatic_optimization):
         dict(name="setup", kwargs=dict(stage="fit")),
         dict(name="configure_sharded_model"),
         dict(name="Callback.on_configure_sharded_model", args=(trainer, model)),
-        # DeepSpeed skips initializing optimizers here as they are handled via config
-        *([dict(name="configure_optimizers")] if kwargs.get("strategy") != "deepspeed" else []),
+        dict(name="configure_optimizers"),
         dict(name="Callback.on_fit_start", args=(trainer, model)),
         dict(name="on_fit_start"),
-        # TODO: explore whether DeepSpeed can have the same flow for optimizers
-        # DeepSpeed did not find any optimizer in the config so they are loaded here
-        *([dict(name="configure_optimizers")] if kwargs.get("strategy") == "deepspeed" else []),
         dict(name="Callback.on_pretrain_routine_start", args=(trainer, model)),
         dict(name="on_pretrain_routine_start"),
         dict(name="Callback.on_pretrain_routine_end", args=(trainer, model)),

@@ -110,11 +110,10 @@ def test_ddp_configure_ddp():
     # test wrap the model if fitting
     trainer.state.fn = TrainerFn.FITTING
     trainer.training_type_plugin.connect(model)
-    trainer.training_type_plugin.setup_environment()
-    trainer.training_type_plugin.setup(trainer)
     trainer.lightning_module.trainer = trainer
+    trainer.training_type_plugin.setup_environment()
     assert isinstance(trainer.model, LightningModule)
-    trainer._pre_dispatch()
+    trainer.training_type_plugin.setup(trainer)
     # in DDPPlugin configure_ddp(), model wrapped by DistributedDataParallel
     assert isinstance(trainer.model, DistributedDataParallel)
@@ -123,10 +122,10 @@
         strategy=ddp_plugin,
     )
     # test do not wrap the model if trainerFN is not fitting
     trainer.state.fn = TrainerFn.VALIDATING
     trainer.training_type_plugin.connect(model)
+    trainer.lightning_module.trainer = trainer
     trainer.training_type_plugin.setup_environment()
     trainer.training_type_plugin.setup(trainer)
-    trainer.lightning_module.trainer = trainer
-    trainer._pre_dispatch()
     # in DDPPlugin configure_ddp(), model are still LightningModule
     assert isinstance(trainer.model, LightningModule)

@@ -104,8 +104,8 @@ def test_first_logger_call_in_subprocess(tmpdir):
     """

     class LoggerCallsObserver(Callback):
-        def on_fit_start(self, trainer, pl_module):
-            # this hook is executed directly before Trainer.pre_dispatch
+        def setup(self, trainer, pl_module, stage):
+            # this hook is executed after Strategy has setup the environment
             # logger should not write any logs until this point
             assert not trainer.logger.method_calls
             assert not os.listdir(trainer.logger.save_dir)