[IPU] Add hooks for IPU lifecycle 4/5 (#7864)
This commit is contained in:
parent
ea71cf4a5f
commit
41be61c6f2
|
@ -62,6 +62,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Added reset dataloader hooks to Training Plugins and Accelerators ([#7861](https://github.com/PyTorchLightning/pytorch-lightning/pull/7861))
|
||||
|
||||
|
||||
- Added trainer stage hooks for Training Plugins and Accelerators ([#7864](https://github.com/PyTorchLightning/pytorch-lightning/pull/7864))
|
||||
|
||||
|
||||
### Changed
|
||||
|
||||
- Changed calling of `untoggle_optimizer(opt_idx)` out of the closure function ([#7563](https://github.com/PyTorchLightning/pytorch-lightning/pull/7563)
|
||||
|
|
|
@ -179,10 +179,6 @@ class Accelerator:
|
|||
|
||||
return move_data_to_device(batch, device)
|
||||
|
||||
def on_train_start(self) -> None:
|
||||
"""Hook to do something upon the training start"""
|
||||
pass
|
||||
|
||||
def training_step(
|
||||
self,
|
||||
step_kwargs: Dict[str, Union[Any, int]],
|
||||
|
@ -348,14 +344,6 @@ class Accelerator:
|
|||
model=self.model,
|
||||
)
|
||||
|
||||
def on_train_epoch_end(self) -> None:
|
||||
"""Hook to do something on the end of an training epoch."""
|
||||
pass
|
||||
|
||||
def on_train_end(self) -> None:
|
||||
"""Hook to do something at the end of the training"""
|
||||
pass
|
||||
|
||||
def setup_optimizers(self, trainer: 'pl.Trainer') -> None:
|
||||
"""
|
||||
Creates optimizers and schedulers
|
||||
|
@ -563,3 +551,45 @@ class Accelerator:
|
|||
|
||||
def update_global_step(self, total_batch_idx: int, current_global_step: int) -> int:
|
||||
return self.training_type_plugin.update_global_step(total_batch_idx, current_global_step)
|
||||
|
||||
def on_train_epoch_end(self) -> None:
|
||||
"""Hook to do something on the end of an training epoch."""
|
||||
pass
|
||||
|
||||
def on_train_start(self) -> None:
|
||||
"""Called when train begins."""
|
||||
return self.training_type_plugin.on_train_start()
|
||||
|
||||
def on_validation_start(self) -> None:
|
||||
"""Called when validation begins."""
|
||||
return self.training_type_plugin.on_validation_start()
|
||||
|
||||
def on_test_start(self) -> None:
|
||||
"""Called when test begins."""
|
||||
return self.training_type_plugin.on_test_start()
|
||||
|
||||
def on_predict_start(self) -> None:
|
||||
"""Called when predict begins."""
|
||||
return self.training_type_plugin.on_predict_start()
|
||||
|
||||
def on_validation_end(self) -> None:
|
||||
"""Called when validation ends."""
|
||||
return self.training_type_plugin.on_validation_end()
|
||||
|
||||
def on_test_end(self) -> None:
|
||||
"""Called when test end."""
|
||||
return self.training_type_plugin.on_test_end()
|
||||
|
||||
def on_predict_end(self) -> None:
|
||||
"""Called when predict ends."""
|
||||
return self.training_type_plugin.on_predict_end()
|
||||
|
||||
def on_train_end(self) -> None:
|
||||
"""Called when train ends."""
|
||||
return self.training_type_plugin.on_train_end()
|
||||
|
||||
def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
|
||||
"""
|
||||
Called in the training loop before anything happens for that batch.
|
||||
"""
|
||||
return self.training_type_plugin.on_train_batch_start(batch, batch_idx, dataloader_idx)
|
||||
|
|
|
@ -330,3 +330,41 @@ class TrainingTypePlugin(Plugin, ABC):
|
|||
def should_rank_save_checkpoint(self) -> bool:
|
||||
"""Returns whether the checkpoint should be saved (rank based)"""
|
||||
return self.is_global_zero
|
||||
|
||||
def on_train_start(self) -> None:
|
||||
"""Called when train begins."""
|
||||
pass
|
||||
|
||||
def on_validation_start(self) -> None:
|
||||
"""Called when validation begins."""
|
||||
pass
|
||||
|
||||
def on_test_start(self) -> None:
|
||||
"""Called when test begins."""
|
||||
pass
|
||||
|
||||
def on_predict_start(self) -> None:
|
||||
"""Called when predict begins."""
|
||||
pass
|
||||
|
||||
def on_train_end(self) -> None:
|
||||
"""Called when train ends."""
|
||||
pass
|
||||
|
||||
def on_validation_end(self) -> None:
|
||||
"""Called when validation ends."""
|
||||
pass
|
||||
|
||||
def on_test_end(self) -> None:
|
||||
"""Called when test end."""
|
||||
pass
|
||||
|
||||
def on_predict_end(self):
|
||||
"""Called when predict ends."""
|
||||
pass
|
||||
|
||||
def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
|
||||
"""
|
||||
Called in the training loop before anything happens for that batch.
|
||||
"""
|
||||
pass
|
||||
|
|
Loading…
Reference in New Issue