Standardize model attribute access in training type plugins (#11072)
parent fde326d7e0
commit ffb1a754af
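This change replaces direct use of the private `_model` attribute with the public `model` property across the training type plugins and the DDP comm-hook tests, so every plugin reads and assigns the wrapped module the same way. A minimal sketch of the assumed accessor pattern follows (a simplification for illustration, not the actual Lightning source; the class and names below are illustrative):

    from typing import Optional

    from torch.nn import Module


    class TrainingTypePluginSketch:
        """Illustrative stand-in for a training type plugin."""

        def __init__(self) -> None:
            self._model: Optional[Module] = None  # private storage for the wrapped module

        @property
        def model(self) -> Optional[Module]:
            # Public read access, e.g. `self.model.to(self.root_device)`.
            return self._model

        @model.setter
        def model(self, new_model: Optional[Module]) -> None:
            # Public write access, e.g. `self.model = self._setup_model(...)`.
            self._model = new_model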
@@ -273,7 +273,7 @@ class DDPPlugin(ParallelPlugin):
             _TORCH_GREATER_EQUAL_1_8 and self.on_gpu and self._is_single_process_single_device
         ):
             register_ddp_comm_hook(
-                model=self._model,
+                model=self.model,
                 ddp_comm_state=self._ddp_comm_state,
                 ddp_comm_hook=self._ddp_comm_hook,
                 ddp_comm_wrapper=self._ddp_comm_wrapper,
@@ -330,7 +330,7 @@ class DDPPlugin(ParallelPlugin):

     def configure_ddp(self) -> None:
         self.pre_configure_ddp()
-        self._model = self._setup_model(LightningDistributedModule(self.model))
+        self.model = self._setup_model(LightningDistributedModule(self.model))
         self._register_ddp_hooks()

     def determine_ddp_device_ids(self):
@@ -203,7 +203,7 @@ class DDPSpawnPlugin(ParallelPlugin):
         # https://github.com/pytorch/pytorch/blob/v1.8.0/torch/nn/parallel/distributed.py#L1080-L1084
         if _TORCH_GREATER_EQUAL_1_8 and self.on_gpu and self._is_single_process_single_device:
             register_ddp_comm_hook(
-                model=self._model,
+                model=self.model,
                 ddp_comm_state=self._ddp_comm_state,
                 ddp_comm_hook=self._ddp_comm_hook,
                 ddp_comm_wrapper=self._ddp_comm_wrapper,
@@ -211,7 +211,7 @@ class DDPSpawnPlugin(ParallelPlugin):

     def configure_ddp(self) -> None:
         self.pre_configure_ddp()
-        self._model = self._setup_model(LightningDistributedModule(self.model))
+        self.model = self._setup_model(LightningDistributedModule(self.model))
         self._register_ddp_hooks()

     def determine_ddp_device_ids(self):
@@ -398,9 +398,9 @@ class DeepSpeedPlugin(DDPPlugin):
         # normally we set this to the batch size, but it is not available here unless the user provides it
         # as part of the config
         self.config.setdefault("train_micro_batch_size_per_gpu", 1)
-        self._model, optimizer = self._setup_model_and_optimizer(model, optimizers[0])
+        self.model, optimizer = self._setup_model_and_optimizer(model, optimizers[0])
         self._set_deepspeed_activation_checkpointing()
-        return self._model, [optimizer]
+        return self.model, [optimizer]

     def _setup_model_and_optimizer(
         self, model: Module, optimizer: Optimizer, lr_scheduler: Optional[_LRScheduler] = None
@@ -65,7 +65,7 @@ class DataParallelPlugin(ParallelPlugin):
     def setup(self, trainer: "pl.Trainer") -> None:
         # model needs to be moved to the device before it is wrapped
         self.model_to_device()
-        self._model = self._setup_model(LightningParallelModule(self._model))
+        self.model = self._setup_model(LightningParallelModule(self.model))
         super().setup(trainer)

     def batch_to_device(self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: int = 0) -> Any:
@@ -107,7 +107,7 @@ class DataParallelPlugin(ParallelPlugin):
         return self.parallel_devices[0]

     def model_to_device(self) -> None:
-        self._model.to(self.root_device)
+        self.model.to(self.root_device)

     def barrier(self, *args, **kwargs):
         pass
@@ -58,7 +58,7 @@ class ParallelPlugin(TrainingTypePlugin, ABC):

     @property
     def lightning_module(self) -> Optional["pl.LightningModule"]:
-        return unwrap_lightning_module(self._model) if self._model is not None else None
+        return unwrap_lightning_module(self.model) if self.model is not None else None

     @property
     def global_rank(self) -> int:
@@ -45,7 +45,7 @@ class DDPShardedPlugin(DDPPlugin):
         # For multi-node training, enabling bucketing will improve performance.
         self._ddp_kwargs["reduce_buffer_size"] = self._REDUCE_BUFFER_SIZE_DEFAULT if self.num_nodes > 1 else 0

-        self._model, optimizers = self._setup_model_and_optimizers(
+        self.model, optimizers = self._setup_model_and_optimizers(
             model=LightningShardedDataParallel(self.model),
             optimizers=trainer.optimizers,
         )
@@ -107,7 +107,7 @@ class DDPShardedPlugin(DDPPlugin):
                 "`DDPShardedPlugin` requires `fairscale` to be installed."
                 " Install it by running `pip install fairscale`."
             )
-        return unwrap_lightning_module_sharded(self._model) if self._model is not None else None
+        return unwrap_lightning_module_sharded(self.model) if self.model is not None else None

     def pre_backward(self, closure_loss: torch.Tensor) -> None:
         pass
@@ -41,7 +41,7 @@ class DDPSpawnShardedPlugin(DDPSpawnPlugin):

     def configure_ddp(self) -> None:
         trainer = self.lightning_module.trainer
-        self._model, optimizers = self._setup_model_and_optimizers(
+        self.model, optimizers = self._setup_model_and_optimizers(
             model=LightningShardedDataParallel(self.model),
             optimizers=trainer.optimizers,
         )
@@ -106,7 +106,7 @@ class DDPSpawnShardedPlugin(DDPSpawnPlugin):
                 "`DDPSpawnShardedPlugin` requires `fairscale` to be installed."
                 " Install it by running `pip install fairscale`."
             )
-        return unwrap_lightning_module_sharded(self._model) if self._model is not None else None
+        return unwrap_lightning_module_sharded(self.model) if self.model is not None else None

     def pre_backward(self, closure_loss: torch.Tensor) -> None:
         pass
@@ -68,7 +68,7 @@ class SingleDevicePlugin(TrainingTypePlugin):
         return self.device

     def model_to_device(self) -> None:
-        self._model.to(self.root_device)
+        self.model.to(self.root_device)

     def setup(self, trainer: "pl.Trainer") -> None:
         self.model_to_device()
@@ -132,7 +132,7 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
             set_shared_parameters(self.model.module, shared_params)

         self.setup_optimizers(trainer)
-        self.precision_plugin.connect(self._model, None, None)
+        self.precision_plugin.connect(self.model, None, None)

     def setup(self, trainer: "pl.Trainer") -> None:
         self.start_method = "fork"
@@ -293,7 +293,7 @@ class TrainingTypePlugin(ABC):
     @property
     def lightning_module(self) -> Optional["pl.LightningModule"]:
         """Returns the pure LightningModule without potential wrappers."""
-        return unwrap_lightning_module(self._model) if self._model is not None else None
+        return unwrap_lightning_module(self.model) if self.model is not None else None

     def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:
         torch.cuda.empty_cache()
@@ -40,7 +40,7 @@ def test_ddp_fp16_compress_comm_hook(tmpdir):
         fast_dev_run=True,
     )
     trainer.fit(model)
-    trainer_comm_hook = trainer.accelerator.training_type_plugin._model.get_ddp_logging_data().comm_hook
+    trainer_comm_hook = trainer.accelerator.training_type_plugin.model.get_ddp_logging_data().comm_hook
     expected_comm_hook = default.fp16_compress_hook.__qualname__
     assert trainer_comm_hook == expected_comm_hook
     assert trainer.state.finished, f"Training failed with {trainer.state}"
@@ -63,7 +63,7 @@ def test_ddp_sgd_comm_hook(tmpdir):
         fast_dev_run=True,
     )
     trainer.fit(model)
-    trainer_comm_hook = trainer.accelerator.training_type_plugin._model.get_ddp_logging_data().comm_hook
+    trainer_comm_hook = trainer.accelerator.training_type_plugin.model.get_ddp_logging_data().comm_hook
     expected_comm_hook = powerSGD.powerSGD_hook.__qualname__
     assert trainer_comm_hook == expected_comm_hook
     assert trainer.state.finished, f"Training failed with {trainer.state}"
@@ -87,7 +87,7 @@ def test_ddp_fp16_compress_wrap_sgd_comm_hook(tmpdir):
         fast_dev_run=True,
     )
     trainer.fit(model)
-    trainer_comm_hook = trainer.accelerator.training_type_plugin._model.get_ddp_logging_data().comm_hook
+    trainer_comm_hook = trainer.accelerator.training_type_plugin.model.get_ddp_logging_data().comm_hook
    expected_comm_hook = default.fp16_compress_wrapper(powerSGD.powerSGD_hook).__qualname__
     assert trainer_comm_hook == expected_comm_hook
     assert trainer.state.finished, f"Training failed with {trainer.state}"
@@ -132,7 +132,7 @@ def test_ddp_post_local_sgd_comm_hook(tmpdir):
         sync_batchnorm=True,
     )
     trainer.fit(model)
-    trainer_comm_hook = trainer.accelerator.training_type_plugin._model.get_ddp_logging_data().comm_hook
+    trainer_comm_hook = trainer.accelerator.training_type_plugin.model.get_ddp_logging_data().comm_hook
     expected_comm_hook = post_localSGD.post_localSGD_hook.__qualname__
     assert trainer_comm_hook == expected_comm_hook
     assert trainer.state.finished, f"Training failed with {trainer.state}"
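A minimal usage sketch, assuming a fitted Trainer (hypothetical illustration, not part of the diff): code outside a plugin, such as the tests above, reads the wrapped module and the unwrapped LightningModule through the same public accessors the change standardizes on.

    # Hypothetical example; `trainer` stands for any fitted pl.Trainer.
    plugin = trainer.accelerator.training_type_plugin
    wrapped = plugin.model                        # public property instead of the private `_model`
    lightning_module = plugin.lightning_module    # pure LightningModule without wrappers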