3/n Simplify spawn plugins: Merge `pre_dispatch` and `setup` logic (#11137)
parent 2e47e2f4ae
commit f5c2881b68
@@ -112,6 +112,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
     * All spawn-based plugins now spawn processes immediately upon calling `Trainer.{fit,validate,test,predict}`
     * The hooks/callbacks `prepare_data`, `setup`, `configure_sharded_model` and `teardown` now run under initialized process group for spawn-based plugins just like their non-spawn counterparts
     * Some configuration errors that were previously raised as `MisconfigurationException`s will now be raised as `ProcessRaisedException` (torch>=1.8) or as `Exception` (torch<1.8)
+    * Removed the `TrainingTypePlugin.pre_dispatch()` method and merged it with `TrainingTypePlugin.setup()` ([#11137](https://github.com/PyTorchLightning/pytorch-lightning/pull/11137))

 - Changed profiler to index and display the names of the hooks with a new pattern [<base class>]<class>.<hook name> ([#11026](https://github.com/PyTorchLightning/pytorch-lightning/pull/11026))
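For plugin authors, the practical migration is mechanical: logic that previously lived in a `pre_dispatch` override now belongs in `setup`, which for spawn-based plugins also runs with the process group already initialized. A minimal sketch of the before/after, with a hypothetical custom plugin (class name and the elided body are illustrative, not code from this commit):

    import pytorch_lightning as pl
    from pytorch_lightning.plugins import SingleDevicePlugin


    class MyCustomPlugin(SingleDevicePlugin):  # hypothetical plugin, for illustration only
        # Previously, per-run preparation was overridden in `pre_dispatch`:
        #
        #     def pre_dispatch(self, trainer: "pl.Trainer") -> None:
        #         super().pre_dispatch(trainer)
        #         ...  # move/wrap the model, share state across processes
        #
        # With this change, the same logic moves into `setup`:
        def setup(self, trainer: "pl.Trainer") -> None:
            # the base class handles accelerator, optimizer and precision-plugin setup
            # and moves the optimizer state
            super().setup(trainer)
            ...  # move/wrap the model, share state across processes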
@@ -151,6 +151,24 @@ class DDPPlugin(ParallelPlugin):
         self.setup_distributed()
         super().setup_environment()

+    def setup(self, trainer: "pl.Trainer") -> None:
+        super().setup(trainer)
+        # share ddp pids to all processes
+        self._rank_0_has_called_call_children_scripts = self.broadcast(self._rank_0_has_called_call_children_scripts)
+        if self._should_run_deadlock_detection():
+            self._share_information_to_prevent_deadlock()
+
+        # move the model to the correct device
+        self.model_to_device()
+
+        if self.sync_batchnorm:
+            self.model = self.configure_sync_batchnorm(self.model)
+
+        # skip wrapping the model if we are not fitting as no gradients need to be exchanged
+        trainer_fn = self.lightning_module.trainer.state.fn
+        if trainer_fn == TrainerFn.FITTING:
+            self.configure_ddp()
+
     def _setup_model(self, model: Module) -> DistributedDataParallel:
         """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module."""
         return DistributedDataParallel(module=model, device_ids=self.determine_ddp_device_ids(), **self._ddp_kwargs)
@@ -341,24 +359,6 @@ class DDPPlugin(ParallelPlugin):
             return None
         return [self.root_device.index]

-    def pre_dispatch(self, trainer: "pl.Trainer") -> None:
-        super().pre_dispatch(trainer)
-        # share ddp pids to all processes
-        self._rank_0_has_called_call_children_scripts = self.broadcast(self._rank_0_has_called_call_children_scripts)
-        if self._should_run_deadlock_detection():
-            self._share_information_to_prevent_deadlock()
-
-        # move the model to the correct device
-        self.model_to_device()
-
-        if self.sync_batchnorm:
-            self.model = self.configure_sync_batchnorm(self.model)
-
-        # skip wrapping the model if we are not fitting as no gradients need to be exchanged
-        trainer_fn = self.lightning_module.trainer.state.fn
-        if trainer_fn == TrainerFn.FITTING:
-            self.configure_ddp()
-
     def barrier(self, *args, **kwargs) -> None:
         if not distributed_available():
             return
@@ -120,6 +120,17 @@ class DDPSpawnPlugin(ParallelPlugin):
         os.environ["MASTER_PORT"] = str(self.cluster_environment.main_port)
         super().setup(trainer)

+        # move the model to the correct device
+        self.model_to_device()
+
+        if self.sync_batchnorm:
+            self.model = self.configure_sync_batchnorm(self.model)
+
+        # skip wrapping the model if we are not fitting as no gradients need to be exchanged
+        trainer_fn = self.lightning_module.trainer.state.fn
+        if trainer_fn == TrainerFn.FITTING:
+            self.configure_ddp()
+
     def _setup_model(self, model: Module) -> DistributedDataParallel:
         """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module."""
         return DistributedDataParallel(module=model, device_ids=self.determine_ddp_device_ids(), **self._ddp_kwargs)
@@ -170,20 +181,6 @@ class DDPSpawnPlugin(ParallelPlugin):
             self.cluster_environment, self.torch_distributed_backend, self.global_rank, self.world_size
         )

-    def pre_dispatch(self, trainer: "pl.Trainer") -> None:
-        super().pre_dispatch(trainer)
-
-        # move the model to the correct device
-        self.model_to_device()
-
-        if self.sync_batchnorm:
-            self.model = self.configure_sync_batchnorm(self.model)
-
-        # skip wrapping the model if we are not fitting as no gradients need to be exchanged
-        trainer_fn = self.lightning_module.trainer.state.fn
-        if trainer_fn == TrainerFn.FITTING:
-            self.configure_ddp()
-
     def pre_configure_ddp(self):
         # if unset, default `find_unused_parameters` `True`
         # Many models require setting this parameter to True, as there are corner cases
@@ -346,6 +346,14 @@ class DeepSpeedPlugin(DDPPlugin):
             self._format_config()
             self._config_initialized = True

+    def setup(self, trainer: "pl.Trainer") -> None:
+        self.accelerator.setup(trainer)
+        self.setup_optimizers(trainer)
+        self.setup_precision_plugin()
+        self._move_optimizer_state()
+        self.init_deepspeed()
+        self.barrier()
+
     def _init_deepspeed_distributed(self) -> None:
         if platform.system() != "Windows":
             # do not set env variables on windows, allow deepspeed to control setup
@@ -368,11 +376,6 @@ class DeepSpeedPlugin(DDPPlugin):
     def restore_checkpoint_after_pre_dispatch(self) -> bool:
         return True

-    def pre_dispatch(self, trainer: "pl.Trainer") -> None:
-        self._move_optimizer_state()
-        self.init_deepspeed()
-        self.barrier()
-
     def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer]) -> Tuple[Module, List[Optimizer]]:
         """Setup a model and multiple optimizers together.
@@ -129,6 +129,19 @@ class DDPFullyShardedStrategy(DDPPlugin):
             )
         super().setup_distributed()

+    def setup(self, trainer: "pl.Trainer") -> None:
+        self.accelerator.setup(trainer)
+        self.setup_optimizers(trainer)
+        self.setup_precision_plugin()
+        self._move_optimizer_state()
+
+        if self.sync_batchnorm:
+            self.model = self.configure_sync_batchnorm(self.model)
+
+        self.configure_ddp()
+        self.barrier()
+        self.setup_optimizers(trainer)
+
     @contextlib.contextmanager
     def model_sharded_context(self) -> Generator:
         precision = self.precision_plugin.precision
@@ -163,14 +176,6 @@ class DDPFullyShardedStrategy(DDPPlugin):
         # setup optimizers after fully sharded has wrapped the lightning module
         self.setup_optimizers(self.lightning_module.trainer)

-    def pre_dispatch(self, trainer: "pl.Trainer") -> None:
-        self._move_optimizer_state()
-        if self.sync_batchnorm:
-            self.model = self.configure_sync_batchnorm(self.model)
-        self.configure_ddp()
-        self.barrier()
-        self.setup_optimizers(trainer)
-
     def model_to_device(self) -> None:
         # ensure we update the device type in the lightning module
         self.lightning_module.to(self.root_device)
@@ -79,10 +79,9 @@ class HorovodPlugin(ParallelPlugin):

     def setup(self, trainer: "pl.Trainer") -> None:
         self.model_to_device()

+        super().setup(trainer)

-    def pre_dispatch(self, trainer: "pl.Trainer") -> None:
-        super().pre_dispatch(trainer)
         self._exit_stack = ExitStack()
         self._exit_stack.__enter__()
@@ -126,14 +126,6 @@ class IPUPlugin(ParallelPlugin):

         super().setup(trainer)

-    def setup_optimizers(self, trainer: "pl.Trainer") -> None:
-        super().setup_optimizers(trainer)
-
-        if len(self.optimizers) > 1:
-            raise MisconfigurationException("IPUs currently only support one optimizer.")
-
-    def pre_dispatch(self, trainer: "pl.Trainer") -> None:
-        super().pre_dispatch(trainer)
         model = LightningIPUModule(self.lightning_module, self.precision_plugin.precision)
         self.model = model

@@ -164,6 +156,12 @@ class IPUPlugin(ParallelPlugin):
             model = poptorch.inferenceModel(model=model, options=self.inference_opts)
             self.poptorch_models[RunningStage.PREDICTING] = model

+    def setup_optimizers(self, trainer: "pl.Trainer") -> None:
+        super().setup_optimizers(trainer)
+
+        if len(self.optimizers) > 1:
+            raise MisconfigurationException("IPUs currently only support one optimizer.")
+
     @property
     def replication_factor(self) -> int:
         if not self.lightning_module or not self.poptorch_models:
@@ -64,11 +64,6 @@ class SingleTPUPlugin(SingleDevicePlugin):

         super().setup(trainer)

-    def model_to_device(self) -> None:
-        self.model.to(self.root_device)
-
-    def pre_dispatch(self, trainer: "pl.Trainer") -> None:
-        super().pre_dispatch(trainer)
         if isinstance(self.device, int):
             self.device = xm.xla_device(self.device)
@@ -78,6 +73,9 @@ class SingleTPUPlugin(SingleDevicePlugin):
         self.tpu_local_core_rank = xm.get_local_ordinal()
         self.tpu_global_core_rank = xm.get_ordinal()

+    def model_to_device(self) -> None:
+        self.model.to(self.root_device)
+
     def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: _PATH) -> None:
         """Save model/training states as a checkpoint file through state-dump and file-write.
@@ -120,8 +120,13 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
         self.wrapped_model = xmp.MpModelWrapper(LightningDistributedModule(model))
         return super().connect(model)

-    def pre_dispatch(self, trainer: "pl.Trainer") -> None:
+    def setup(self, trainer: "pl.Trainer") -> None:
         self.start_method = "fork"
+        self.accelerator.setup(trainer)
+        self.setup_optimizers(trainer)
+        self.setup_precision_plugin()
+        self._move_optimizer_state()
+
         if self.debug:
             os.environ["PT_XLA_DEBUG"] = str(1)
@@ -121,6 +121,7 @@ class Strategy(ABC):
         self.accelerator.setup(trainer)
         self.setup_optimizers(trainer)
         self.setup_precision_plugin()
+        self._move_optimizer_state()

     def setup_precision_plugin(self) -> None:
         """Attaches the precision plugin to the accelerator."""
@@ -490,10 +491,6 @@ class Strategy(ABC):
         """Called in the training loop before anything happens for that batch."""
         pass

-    def pre_dispatch(self, trainer: "pl.Trainer") -> None:
-        """Hook to do something before the training/evaluation/prediction starts."""
-        self._move_optimizer_state()
-
     def dispatch(self, trainer: "pl.Trainer") -> None:
         """Hook to do something before the training/evaluation/prediction starts."""
         self.precision_plugin.dispatch(trainer)
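With `pre_dispatch` removed from the base class, the strategy-facing lifecycle that the Trainer drives reduces to `connect`, `setup_environment`, `setup`, and then `dispatch`, which the updated `test_ddp_configure_ddp` further below exercises manually. A hedged sketch of that order (the standalone driver function is illustrative, not an API from this commit):

    import pytorch_lightning as pl


    def drive_plugin_manually(trainer: "pl.Trainer", model: "pl.LightningModule") -> None:
        # Illustrative only: the Trainer performs these calls itself during fit/validate/test/predict.
        plugin = trainer.training_type_plugin
        plugin.connect(model)       # attach the model to the plugin
        plugin.setup_environment()  # set up the cluster environment / process group
        plugin.setup(trainer)       # accelerator, optimizers, precision, device placement and,
                                    # when fitting, model wrapping (the former pre_dispatch work)
        plugin.dispatch(trainer)    # precision-plugin dispatch right before the loops run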
@@ -1111,7 +1111,6 @@ class Trainer(
             self._restore_modules_and_callbacks(ckpt_path)

         self._call_configure_sharded_model()  # allow user to setup in model sharded environment
-        self.training_type_plugin.setup(self)

         # ----------------------------
         # INSPECT THE CORE LOOPS
@@ -1145,13 +1144,15 @@ class Trainer(
         self.logger_connector.reset_results()
         self.logger_connector.reset_metrics()

+        # strategy will configure model and move it to the device
+        self.training_type_plugin.setup(self)
+
         # hook
         if self.state.fn == TrainerFn.FITTING:
             self._call_callback_hooks("on_fit_start")
             self._call_lightning_module_hook("on_fit_start")

-        # plugin will move model to device
-        self._pre_dispatch()
+        self._log_hyperparams()

         if self.training_type_plugin.restore_checkpoint_after_pre_dispatch:
             self._restore_modules_and_callbacks(ckpt_path)
@@ -1183,10 +1184,6 @@ class Trainer(

         return results

-    def _pre_dispatch(self):
-        self.training_type_plugin.pre_dispatch(self)
-        self._log_hyperparams()
-
     def _log_hyperparams(self) -> None:
         # log hyper-parameters
         hparams_initial = None
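As the changelog entries above state, `prepare_data`, `setup`, `configure_sharded_model` and `teardown` now run with the process group already initialized for spawn-based plugins as well, which is also what the `LoggerCallsObserver.setup` change in the logger test below relies on. A small hedged illustration of a callback taking advantage of that guarantee (example written for this note, not taken from the commit):

    import torch.distributed as dist
    from pytorch_lightning import Callback


    class DistributedAwareSetup(Callback):  # illustrative callback, not part of this commit
        def setup(self, trainer, pl_module, stage):
            # For spawn-based plugins the worker processes are spawned immediately on
            # Trainer.{fit,validate,test,predict}, so by the time `setup` runs the process
            # group is initialized and collective queries like these are safe.
            if dist.is_available() and dist.is_initialized():
                rank, world_size = dist.get_rank(), dist.get_world_size()
                print(f"setup running on rank {rank}/{world_size} for stage={stage}")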
@@ -28,18 +28,18 @@ def test_restore_checkpoint_after_pre_dispatch(tmpdir, restore_after_pre_dispatc
    dispatch is called."""

    class TestPlugin(SingleDevicePlugin):
-        predispatched_called = False
+        setup_called = False

-        def pre_dispatch(self, trainer: "pl.Trainer") -> None:
-            super().pre_dispatch(trainer)
-            self.predispatched_called = True
+        def setup(self, trainer: "pl.Trainer") -> None:
+            super().setup(trainer)
+            self.setup_called = True

        @property
        def restore_checkpoint_after_pre_dispatch(self) -> bool:
            return restore_after_pre_dispatch

        def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
-            assert self.predispatched_called == restore_after_pre_dispatch
+            assert self.setup_called == restore_after_pre_dispatch
            return super().load_checkpoint(checkpoint_path)

    model = BoringModel()
@@ -60,5 +60,5 @@ def test_restore_checkpoint_after_pre_dispatch(tmpdir, restore_after_pre_dispatc
    trainer = Trainer(default_root_dir=tmpdir, strategy=plugin, fast_dev_run=True)
    trainer.fit(model, ckpt_path=checkpoint_path)
    for func in (trainer.test, trainer.validate, trainer.predict):
-        plugin.predispatched_called = False
+        plugin.setup_called = False
        func(model, ckpt_path=checkpoint_path)
@@ -436,7 +436,7 @@ def test_replication_factor(tmpdir):
     plugin.model = model
     model.trainer = trainer
     trainer.state.fn = TrainerFn.FITTING
-    trainer.training_type_plugin.pre_dispatch(trainer)
+    trainer.training_type_plugin.setup(trainer)

     trainer.state.stage = RunningStage.TRAINING
     assert trainer.training_type_plugin.replication_factor == 8
@@ -450,7 +450,7 @@ def test_replication_factor(tmpdir):
     ):
         trainer.state.fn = fn
         trainer.state.stage = stage
-        trainer.training_type_plugin.pre_dispatch(trainer)
+        trainer.training_type_plugin.setup(trainer)
         assert trainer.training_type_plugin.replication_factor == 7
@@ -585,7 +585,7 @@ def test_poptorch_models_at_different_stages(tmpdir):

     trainer.optimizers = model.configure_optimizers()[0]
     trainer.state.fn = TrainerFn.FITTING
-    trainer.training_type_plugin.pre_dispatch(trainer)
+    trainer.training_type_plugin.setup(trainer)
     assert list(trainer.training_type_plugin.poptorch_models) == [RunningStage.TRAINING, RunningStage.VALIDATING]

     for fn, stage in (
@@ -595,7 +595,7 @@ def test_poptorch_models_at_different_stages(tmpdir):
     ):
         trainer.state.fn = fn
         trainer.state.stage = stage
-        trainer.training_type_plugin.pre_dispatch(trainer)
+        trainer.training_type_plugin.setup(trainer)
         assert list(trainer.training_type_plugin.poptorch_models) == [stage]
@@ -516,13 +516,9 @@ def _run_trainer_model_hook_system_fit(kwargs, tmpdir, automatic_optimization):
         dict(name="setup", kwargs=dict(stage="fit")),
         dict(name="configure_sharded_model"),
         dict(name="Callback.on_configure_sharded_model", args=(trainer, model)),
-        # DeepSpeed skips initializing optimizers here as they are handled via config
-        *([dict(name="configure_optimizers")] if kwargs.get("strategy") != "deepspeed" else []),
+        dict(name="configure_optimizers"),
         dict(name="Callback.on_fit_start", args=(trainer, model)),
         dict(name="on_fit_start"),
-        # TODO: explore whether DeepSpeed can have the same flow for optimizers
-        # DeepSpeed did not find any optimizer in the config so they are loaded here
-        *([dict(name="configure_optimizers")] if kwargs.get("strategy") == "deepspeed" else []),
         dict(name="Callback.on_pretrain_routine_start", args=(trainer, model)),
         dict(name="on_pretrain_routine_start"),
         dict(name="Callback.on_pretrain_routine_end", args=(trainer, model)),
@@ -110,11 +110,10 @@ def test_ddp_configure_ddp():
     # test wrap the model if fitting
     trainer.state.fn = TrainerFn.FITTING
     trainer.training_type_plugin.connect(model)
-    trainer.training_type_plugin.setup_environment()
-    trainer.training_type_plugin.setup(trainer)
     trainer.lightning_module.trainer = trainer
+    trainer.training_type_plugin.setup_environment()
     assert isinstance(trainer.model, LightningModule)
-    trainer._pre_dispatch()
+    trainer.training_type_plugin.setup(trainer)
     # in DDPPlugin configure_ddp(), model wrapped by DistributedDataParallel
     assert isinstance(trainer.model, DistributedDataParallel)

@@ -123,10 +122,10 @@ def test_ddp_configure_ddp():
         strategy=ddp_plugin,
     )
     # test do not wrap the model if trainerFN is not fitting
     trainer.state.fn = TrainerFn.VALIDATING
     trainer.training_type_plugin.connect(model)
+    trainer.lightning_module.trainer = trainer
+    trainer.training_type_plugin.setup_environment()
     trainer.training_type_plugin.setup(trainer)
-    trainer.lightning_module.trainer = trainer
-    trainer._pre_dispatch()
     # in DDPPlugin configure_ddp(), model are still LightningModule
     assert isinstance(trainer.model, LightningModule)
@@ -104,8 +104,8 @@ def test_first_logger_call_in_subprocess(tmpdir):
    """

    class LoggerCallsObserver(Callback):
-        def on_fit_start(self, trainer, pl_module):
-            # this hook is executed directly before Trainer.pre_dispatch
+        def setup(self, trainer, pl_module, stage):
+            # this hook is executed after Strategy has setup the environment
            # logger should not write any logs until this point
            assert not trainer.logger.method_calls
            assert not os.listdir(trainer.logger.save_dir)