From ffb1a754af5301608c71955c12468216e920af51 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 15 Dec 2021 16:37:21 +0100 Subject: [PATCH] Standardize model attribute access in training type plugins (#11072) --- pytorch_lightning/plugins/training_type/ddp.py | 4 ++-- pytorch_lightning/plugins/training_type/ddp_spawn.py | 4 ++-- pytorch_lightning/plugins/training_type/deepspeed.py | 4 ++-- pytorch_lightning/plugins/training_type/dp.py | 4 ++-- pytorch_lightning/plugins/training_type/parallel.py | 2 +- pytorch_lightning/plugins/training_type/sharded.py | 4 ++-- pytorch_lightning/plugins/training_type/sharded_spawn.py | 4 ++-- pytorch_lightning/plugins/training_type/single_device.py | 2 +- pytorch_lightning/plugins/training_type/tpu_spawn.py | 2 +- .../plugins/training_type/training_type_plugin.py | 2 +- tests/plugins/test_ddp_plugin_with_comm_hook.py | 8 ++++---- 11 files changed, 20 insertions(+), 20 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index 62d198536b..829735b0e0 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -273,7 +273,7 @@ class DDPPlugin(ParallelPlugin): _TORCH_GREATER_EQUAL_1_8 and self.on_gpu and self._is_single_process_single_device ): register_ddp_comm_hook( - model=self._model, + model=self.model, ddp_comm_state=self._ddp_comm_state, ddp_comm_hook=self._ddp_comm_hook, ddp_comm_wrapper=self._ddp_comm_wrapper, @@ -330,7 +330,7 @@ class DDPPlugin(ParallelPlugin): def configure_ddp(self) -> None: self.pre_configure_ddp() - self._model = self._setup_model(LightningDistributedModule(self.model)) + self.model = self._setup_model(LightningDistributedModule(self.model)) self._register_ddp_hooks() def determine_ddp_device_ids(self): diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index ee7a1efa28..975f4ba435 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -203,7 +203,7 @@ class DDPSpawnPlugin(ParallelPlugin): # https://github.com/pytorch/pytorch/blob/v1.8.0/torch/nn/parallel/distributed.py#L1080-L1084 if _TORCH_GREATER_EQUAL_1_8 and self.on_gpu and self._is_single_process_single_device: register_ddp_comm_hook( - model=self._model, + model=self.model, ddp_comm_state=self._ddp_comm_state, ddp_comm_hook=self._ddp_comm_hook, ddp_comm_wrapper=self._ddp_comm_wrapper, @@ -211,7 +211,7 @@ class DDPSpawnPlugin(ParallelPlugin): def configure_ddp(self) -> None: self.pre_configure_ddp() - self._model = self._setup_model(LightningDistributedModule(self.model)) + self.model = self._setup_model(LightningDistributedModule(self.model)) self._register_ddp_hooks() def determine_ddp_device_ids(self): diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index cc9cd4937c..f30d15d495 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -398,9 +398,9 @@ class DeepSpeedPlugin(DDPPlugin): # normally we set this to the batch size, but it is not available here unless the user provides it # as part of the config self.config.setdefault("train_micro_batch_size_per_gpu", 1) - self._model, optimizer = self._setup_model_and_optimizer(model, optimizers[0]) + self.model, optimizer = self._setup_model_and_optimizer(model, optimizers[0]) 
self._set_deepspeed_activation_checkpointing() - return self._model, [optimizer] + return self.model, [optimizer] def _setup_model_and_optimizer( self, model: Module, optimizer: Optimizer, lr_scheduler: Optional[_LRScheduler] = None diff --git a/pytorch_lightning/plugins/training_type/dp.py b/pytorch_lightning/plugins/training_type/dp.py index 3016ee7462..69ba2fed86 100644 --- a/pytorch_lightning/plugins/training_type/dp.py +++ b/pytorch_lightning/plugins/training_type/dp.py @@ -65,7 +65,7 @@ class DataParallelPlugin(ParallelPlugin): def setup(self, trainer: "pl.Trainer") -> None: # model needs to be moved to the device before it is wrapped self.model_to_device() - self._model = self._setup_model(LightningParallelModule(self._model)) + self.model = self._setup_model(LightningParallelModule(self.model)) super().setup(trainer) def batch_to_device(self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: int = 0) -> Any: @@ -107,7 +107,7 @@ class DataParallelPlugin(ParallelPlugin): return self.parallel_devices[0] def model_to_device(self) -> None: - self._model.to(self.root_device) + self.model.to(self.root_device) def barrier(self, *args, **kwargs): pass diff --git a/pytorch_lightning/plugins/training_type/parallel.py b/pytorch_lightning/plugins/training_type/parallel.py index 2dc2a95f03..293e52170d 100644 --- a/pytorch_lightning/plugins/training_type/parallel.py +++ b/pytorch_lightning/plugins/training_type/parallel.py @@ -58,7 +58,7 @@ class ParallelPlugin(TrainingTypePlugin, ABC): @property def lightning_module(self) -> Optional["pl.LightningModule"]: - return unwrap_lightning_module(self._model) if self._model is not None else None + return unwrap_lightning_module(self.model) if self.model is not None else None @property def global_rank(self) -> int: diff --git a/pytorch_lightning/plugins/training_type/sharded.py b/pytorch_lightning/plugins/training_type/sharded.py index 280d38bc83..e0486d8635 100644 --- a/pytorch_lightning/plugins/training_type/sharded.py +++ b/pytorch_lightning/plugins/training_type/sharded.py @@ -45,7 +45,7 @@ class DDPShardedPlugin(DDPPlugin): # For multi-node training, enabling bucketing will improve performance. self._ddp_kwargs["reduce_buffer_size"] = self._REDUCE_BUFFER_SIZE_DEFAULT if self.num_nodes > 1 else 0 - self._model, optimizers = self._setup_model_and_optimizers( + self.model, optimizers = self._setup_model_and_optimizers( model=LightningShardedDataParallel(self.model), optimizers=trainer.optimizers, ) @@ -107,7 +107,7 @@ class DDPShardedPlugin(DDPPlugin): "`DDPShardedPlugin` requires `fairscale` to be installed." " Install it by running `pip install fairscale`." 
) - return unwrap_lightning_module_sharded(self._model) if self._model is not None else None + return unwrap_lightning_module_sharded(self.model) if self.model is not None else None def pre_backward(self, closure_loss: torch.Tensor) -> None: pass diff --git a/pytorch_lightning/plugins/training_type/sharded_spawn.py b/pytorch_lightning/plugins/training_type/sharded_spawn.py index 951a0be78e..a5607d6f19 100644 --- a/pytorch_lightning/plugins/training_type/sharded_spawn.py +++ b/pytorch_lightning/plugins/training_type/sharded_spawn.py @@ -41,7 +41,7 @@ class DDPSpawnShardedPlugin(DDPSpawnPlugin): def configure_ddp(self) -> None: trainer = self.lightning_module.trainer - self._model, optimizers = self._setup_model_and_optimizers( + self.model, optimizers = self._setup_model_and_optimizers( model=LightningShardedDataParallel(self.model), optimizers=trainer.optimizers, ) @@ -106,7 +106,7 @@ class DDPSpawnShardedPlugin(DDPSpawnPlugin): "`DDPSpawnShardedPlugin` requires `fairscale` to be installed." " Install it by running `pip install fairscale`." ) - return unwrap_lightning_module_sharded(self._model) if self._model is not None else None + return unwrap_lightning_module_sharded(self.model) if self.model is not None else None def pre_backward(self, closure_loss: torch.Tensor) -> None: pass diff --git a/pytorch_lightning/plugins/training_type/single_device.py b/pytorch_lightning/plugins/training_type/single_device.py index 9dde35a589..0159e86412 100644 --- a/pytorch_lightning/plugins/training_type/single_device.py +++ b/pytorch_lightning/plugins/training_type/single_device.py @@ -68,7 +68,7 @@ class SingleDevicePlugin(TrainingTypePlugin): return self.device def model_to_device(self) -> None: - self._model.to(self.root_device) + self.model.to(self.root_device) def setup(self, trainer: "pl.Trainer") -> None: self.model_to_device() diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 4afb61705c..013b734597 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -132,7 +132,7 @@ class TPUSpawnPlugin(DDPSpawnPlugin): set_shared_parameters(self.model.module, shared_params) self.setup_optimizers(trainer) - self.precision_plugin.connect(self._model, None, None) + self.precision_plugin.connect(self.model, None, None) def setup(self, trainer: "pl.Trainer") -> None: self.start_method = "fork" diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index 171ce23f2f..0c7e1f8410 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -293,7 +293,7 @@ class TrainingTypePlugin(ABC): @property def lightning_module(self) -> Optional["pl.LightningModule"]: """Returns the pure LightningModule without potential wrappers.""" - return unwrap_lightning_module(self._model) if self._model is not None else None + return unwrap_lightning_module(self.model) if self.model is not None else None def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]: torch.cuda.empty_cache() diff --git a/tests/plugins/test_ddp_plugin_with_comm_hook.py b/tests/plugins/test_ddp_plugin_with_comm_hook.py index 7ee46fe0c5..b67988b3ef 100644 --- a/tests/plugins/test_ddp_plugin_with_comm_hook.py +++ b/tests/plugins/test_ddp_plugin_with_comm_hook.py @@ -40,7 +40,7 @@ def test_ddp_fp16_compress_comm_hook(tmpdir): 
fast_dev_run=True, ) trainer.fit(model) - trainer_comm_hook = trainer.accelerator.training_type_plugin._model.get_ddp_logging_data().comm_hook + trainer_comm_hook = trainer.accelerator.training_type_plugin.model.get_ddp_logging_data().comm_hook expected_comm_hook = default.fp16_compress_hook.__qualname__ assert trainer_comm_hook == expected_comm_hook assert trainer.state.finished, f"Training failed with {trainer.state}" @@ -63,7 +63,7 @@ def test_ddp_sgd_comm_hook(tmpdir): fast_dev_run=True, ) trainer.fit(model) - trainer_comm_hook = trainer.accelerator.training_type_plugin._model.get_ddp_logging_data().comm_hook + trainer_comm_hook = trainer.accelerator.training_type_plugin.model.get_ddp_logging_data().comm_hook expected_comm_hook = powerSGD.powerSGD_hook.__qualname__ assert trainer_comm_hook == expected_comm_hook assert trainer.state.finished, f"Training failed with {trainer.state}" @@ -87,7 +87,7 @@ def test_ddp_fp16_compress_wrap_sgd_comm_hook(tmpdir): fast_dev_run=True, ) trainer.fit(model) - trainer_comm_hook = trainer.accelerator.training_type_plugin._model.get_ddp_logging_data().comm_hook + trainer_comm_hook = trainer.accelerator.training_type_plugin.model.get_ddp_logging_data().comm_hook expected_comm_hook = default.fp16_compress_wrapper(powerSGD.powerSGD_hook).__qualname__ assert trainer_comm_hook == expected_comm_hook assert trainer.state.finished, f"Training failed with {trainer.state}" @@ -132,7 +132,7 @@ def test_ddp_post_local_sgd_comm_hook(tmpdir): sync_batchnorm=True, ) trainer.fit(model) - trainer_comm_hook = trainer.accelerator.training_type_plugin._model.get_ddp_logging_data().comm_hook + trainer_comm_hook = trainer.accelerator.training_type_plugin.model.get_ddp_logging_data().comm_hook expected_comm_hook = post_localSGD.post_localSGD_hook.__qualname__ assert trainer_comm_hook == expected_comm_hook assert trainer.state.finished, f"Training failed with {trainer.state}"
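
Note: every hunk above swaps the private `self._model` attribute for the public `self.model` accessor. The sketch below is not the actual Lightning source (the class name and docstrings are illustrative); it only shows the property pattern this standardization presumably relies on: a `model` property on the training type plugin base class whose getter and setter are backed by the single `_model` slot, so reads and writes through `self.model` stay equivalent to the old `self._model` access.

    from typing import Optional

    import torch.nn as nn


    class TrainingTypePluginSketch:
        """Illustrative stand-in for a training type plugin base class (assumption, not Lightning's code)."""

        def __init__(self) -> None:
            self._model: Optional[nn.Module] = None  # single private storage slot

        @property
        def model(self) -> Optional[nn.Module]:
            # Public accessor used throughout the patch, e.g. `self.model.to(self.root_device)`.
            return self._model

        @model.setter
        def model(self, new_model: Optional[nn.Module]) -> None:
            # Assignments such as `self.model = self._setup_model(...)` land here,
            # so the wrapped module still ends up in `_model`.
            self._model = new_model

With that property in place, subclasses and tests can go through the public attribute only, which keeps wrapping (DDP, sharded, DataParallel) and unwrapping (`lightning_module`) consistent across plugins without touching the private field directly.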