Standardize model attribute access in training type plugins (#11072)

This commit is contained in:
Adrian Wälchli 2021-12-15 16:37:21 +01:00 committed by GitHub
parent fde326d7e0
commit ffb1a754af
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 20 additions and 20 deletions

View File

@ -273,7 +273,7 @@ class DDPPlugin(ParallelPlugin):
_TORCH_GREATER_EQUAL_1_8 and self.on_gpu and self._is_single_process_single_device
):
register_ddp_comm_hook(
model=self._model,
model=self.model,
ddp_comm_state=self._ddp_comm_state,
ddp_comm_hook=self._ddp_comm_hook,
ddp_comm_wrapper=self._ddp_comm_wrapper,
@ -330,7 +330,7 @@ class DDPPlugin(ParallelPlugin):
def configure_ddp(self) -> None:
self.pre_configure_ddp()
self._model = self._setup_model(LightningDistributedModule(self.model))
self.model = self._setup_model(LightningDistributedModule(self.model))
self._register_ddp_hooks()
def determine_ddp_device_ids(self):

View File

@ -203,7 +203,7 @@ class DDPSpawnPlugin(ParallelPlugin):
# https://github.com/pytorch/pytorch/blob/v1.8.0/torch/nn/parallel/distributed.py#L1080-L1084
if _TORCH_GREATER_EQUAL_1_8 and self.on_gpu and self._is_single_process_single_device:
register_ddp_comm_hook(
model=self._model,
model=self.model,
ddp_comm_state=self._ddp_comm_state,
ddp_comm_hook=self._ddp_comm_hook,
ddp_comm_wrapper=self._ddp_comm_wrapper,
@ -211,7 +211,7 @@ class DDPSpawnPlugin(ParallelPlugin):
def configure_ddp(self) -> None:
self.pre_configure_ddp()
self._model = self._setup_model(LightningDistributedModule(self.model))
self.model = self._setup_model(LightningDistributedModule(self.model))
self._register_ddp_hooks()
def determine_ddp_device_ids(self):

View File

@ -398,9 +398,9 @@ class DeepSpeedPlugin(DDPPlugin):
# normally we set this to the batch size, but it is not available here unless the user provides it
# as part of the config
self.config.setdefault("train_micro_batch_size_per_gpu", 1)
self._model, optimizer = self._setup_model_and_optimizer(model, optimizers[0])
self.model, optimizer = self._setup_model_and_optimizer(model, optimizers[0])
self._set_deepspeed_activation_checkpointing()
return self._model, [optimizer]
return self.model, [optimizer]
def _setup_model_and_optimizer(
self, model: Module, optimizer: Optimizer, lr_scheduler: Optional[_LRScheduler] = None

View File

@ -65,7 +65,7 @@ class DataParallelPlugin(ParallelPlugin):
def setup(self, trainer: "pl.Trainer") -> None:
# model needs to be moved to the device before it is wrapped
self.model_to_device()
self._model = self._setup_model(LightningParallelModule(self._model))
self.model = self._setup_model(LightningParallelModule(self.model))
super().setup(trainer)
def batch_to_device(self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: int = 0) -> Any:
@ -107,7 +107,7 @@ class DataParallelPlugin(ParallelPlugin):
return self.parallel_devices[0]
def model_to_device(self) -> None:
self._model.to(self.root_device)
self.model.to(self.root_device)
def barrier(self, *args, **kwargs):
pass

View File

@ -58,7 +58,7 @@ class ParallelPlugin(TrainingTypePlugin, ABC):
@property
def lightning_module(self) -> Optional["pl.LightningModule"]:
return unwrap_lightning_module(self._model) if self._model is not None else None
return unwrap_lightning_module(self.model) if self.model is not None else None
@property
def global_rank(self) -> int:

View File

@ -45,7 +45,7 @@ class DDPShardedPlugin(DDPPlugin):
# For multi-node training, enabling bucketing will improve performance.
self._ddp_kwargs["reduce_buffer_size"] = self._REDUCE_BUFFER_SIZE_DEFAULT if self.num_nodes > 1 else 0
self._model, optimizers = self._setup_model_and_optimizers(
self.model, optimizers = self._setup_model_and_optimizers(
model=LightningShardedDataParallel(self.model),
optimizers=trainer.optimizers,
)
@ -107,7 +107,7 @@ class DDPShardedPlugin(DDPPlugin):
"`DDPShardedPlugin` requires `fairscale` to be installed."
" Install it by running `pip install fairscale`."
)
return unwrap_lightning_module_sharded(self._model) if self._model is not None else None
return unwrap_lightning_module_sharded(self.model) if self.model is not None else None
def pre_backward(self, closure_loss: torch.Tensor) -> None:
pass

View File

@ -41,7 +41,7 @@ class DDPSpawnShardedPlugin(DDPSpawnPlugin):
def configure_ddp(self) -> None:
trainer = self.lightning_module.trainer
self._model, optimizers = self._setup_model_and_optimizers(
self.model, optimizers = self._setup_model_and_optimizers(
model=LightningShardedDataParallel(self.model),
optimizers=trainer.optimizers,
)
@ -106,7 +106,7 @@ class DDPSpawnShardedPlugin(DDPSpawnPlugin):
"`DDPSpawnShardedPlugin` requires `fairscale` to be installed."
" Install it by running `pip install fairscale`."
)
return unwrap_lightning_module_sharded(self._model) if self._model is not None else None
return unwrap_lightning_module_sharded(self.model) if self.model is not None else None
def pre_backward(self, closure_loss: torch.Tensor) -> None:
pass

View File

@ -68,7 +68,7 @@ class SingleDevicePlugin(TrainingTypePlugin):
return self.device
def model_to_device(self) -> None:
self._model.to(self.root_device)
self.model.to(self.root_device)
def setup(self, trainer: "pl.Trainer") -> None:
self.model_to_device()

View File

@ -132,7 +132,7 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
set_shared_parameters(self.model.module, shared_params)
self.setup_optimizers(trainer)
self.precision_plugin.connect(self._model, None, None)
self.precision_plugin.connect(self.model, None, None)
def setup(self, trainer: "pl.Trainer") -> None:
self.start_method = "fork"

View File

@ -293,7 +293,7 @@ class TrainingTypePlugin(ABC):
@property
def lightning_module(self) -> Optional["pl.LightningModule"]:
"""Returns the pure LightningModule without potential wrappers."""
return unwrap_lightning_module(self._model) if self._model is not None else None
return unwrap_lightning_module(self.model) if self.model is not None else None
def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:
torch.cuda.empty_cache()

View File

@ -40,7 +40,7 @@ def test_ddp_fp16_compress_comm_hook(tmpdir):
fast_dev_run=True,
)
trainer.fit(model)
trainer_comm_hook = trainer.accelerator.training_type_plugin._model.get_ddp_logging_data().comm_hook
trainer_comm_hook = trainer.accelerator.training_type_plugin.model.get_ddp_logging_data().comm_hook
expected_comm_hook = default.fp16_compress_hook.__qualname__
assert trainer_comm_hook == expected_comm_hook
assert trainer.state.finished, f"Training failed with {trainer.state}"
@ -63,7 +63,7 @@ def test_ddp_sgd_comm_hook(tmpdir):
fast_dev_run=True,
)
trainer.fit(model)
trainer_comm_hook = trainer.accelerator.training_type_plugin._model.get_ddp_logging_data().comm_hook
trainer_comm_hook = trainer.accelerator.training_type_plugin.model.get_ddp_logging_data().comm_hook
expected_comm_hook = powerSGD.powerSGD_hook.__qualname__
assert trainer_comm_hook == expected_comm_hook
assert trainer.state.finished, f"Training failed with {trainer.state}"
@ -87,7 +87,7 @@ def test_ddp_fp16_compress_wrap_sgd_comm_hook(tmpdir):
fast_dev_run=True,
)
trainer.fit(model)
trainer_comm_hook = trainer.accelerator.training_type_plugin._model.get_ddp_logging_data().comm_hook
trainer_comm_hook = trainer.accelerator.training_type_plugin.model.get_ddp_logging_data().comm_hook
expected_comm_hook = default.fp16_compress_wrapper(powerSGD.powerSGD_hook).__qualname__
assert trainer_comm_hook == expected_comm_hook
assert trainer.state.finished, f"Training failed with {trainer.state}"
@ -132,7 +132,7 @@ def test_ddp_post_local_sgd_comm_hook(tmpdir):
sync_batchnorm=True,
)
trainer.fit(model)
trainer_comm_hook = trainer.accelerator.training_type_plugin._model.get_ddp_logging_data().comm_hook
trainer_comm_hook = trainer.accelerator.training_type_plugin.model.get_ddp_logging_data().comm_hook
expected_comm_hook = post_localSGD.post_localSGD_hook.__qualname__
assert trainer_comm_hook == expected_comm_hook
assert trainer.state.finished, f"Training failed with {trainer.state}"