[IPU] Add hooks for IPU lifecycle 4/5 (#7864)

Commit 41be61c6f2 by Sean Naren, 2021-06-07 13:06:41 +01:00, committed by GitHub
Parent: ea71cf4a5f
3 changed files with 83 additions and 12 deletions

CHANGELOG.md

@@ -62,6 +62,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added reset dataloader hooks to Training Plugins and Accelerators ([#7861](https://github.com/PyTorchLightning/pytorch-lightning/pull/7861))
- Added trainer stage hooks for Training Plugins and Accelerators ([#7864](https://github.com/PyTorchLightning/pytorch-lightning/pull/7864))
### Changed
- Changed calling of `untoggle_optimizer(opt_idx)` out of the closure function ([#7563](https://github.com/PyTorchLightning/pytorch-lightning/pull/7563))
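
The #7864 entry above is this commit: the same stage hooks are added to both `Accelerator` and `TrainingTypePlugin`, with the accelerator forwarding each call to its plugin (see the two diffs below). A rough sketch of the resulting call chain; the `run_train` wrapper is purely illustrative, not real trainer code:

# Conceptual order when the trainer enters the train stage (illustrative only):
#   Trainer -> Accelerator.on_train_start() -> TrainingTypePlugin.on_train_start()
def run_train(accelerator) -> None:
    accelerator.on_train_start()  # forwards to training_type_plugin.on_train_start()
    # ... training epochs run here ...
    accelerator.on_train_end()    # forwards to training_type_plugin.on_train_end()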

pytorch_lightning/accelerators/accelerator.py

@@ -179,10 +179,6 @@ class Accelerator:
        return move_data_to_device(batch, device)

    def on_train_start(self) -> None:
        """Hook to do something upon the training start."""
        pass

    def training_step(
        self,
        step_kwargs: Dict[str, Union[Any, int]],
@@ -348,14 +344,6 @@ class Accelerator:
            model=self.model,
        )

    def on_train_epoch_end(self) -> None:
        """Hook to do something at the end of a training epoch."""
        pass

    def on_train_end(self) -> None:
        """Hook to do something at the end of the training."""
        pass

    def setup_optimizers(self, trainer: 'pl.Trainer') -> None:
        """
        Creates optimizers and schedulers
@@ -563,3 +551,45 @@ class Accelerator:
    def update_global_step(self, total_batch_idx: int, current_global_step: int) -> int:
        return self.training_type_plugin.update_global_step(total_batch_idx, current_global_step)

    def on_train_epoch_end(self) -> None:
        """Hook to do something at the end of a training epoch."""
        pass

    def on_train_start(self) -> None:
        """Called when train begins."""
        return self.training_type_plugin.on_train_start()

    def on_validation_start(self) -> None:
        """Called when validation begins."""
        return self.training_type_plugin.on_validation_start()

    def on_test_start(self) -> None:
        """Called when test begins."""
        return self.training_type_plugin.on_test_start()

    def on_predict_start(self) -> None:
        """Called when predict begins."""
        return self.training_type_plugin.on_predict_start()

    def on_validation_end(self) -> None:
        """Called when validation ends."""
        return self.training_type_plugin.on_validation_end()

    def on_test_end(self) -> None:
        """Called when test ends."""
        return self.training_type_plugin.on_test_end()

    def on_predict_end(self) -> None:
        """Called when predict ends."""
        return self.training_type_plugin.on_predict_end()

    def on_train_end(self) -> None:
        """Called when train ends."""
        return self.training_type_plugin.on_train_end()

    def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
        """Called in the training loop before anything happens for that batch."""
        return self.training_type_plugin.on_train_batch_start(batch, batch_idx, dataloader_idx)
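
The block above turns every stage hook on `Accelerator` into a thin pass-through to the active training type plugin. A minimal, self-contained sketch of that dispatch pattern — `SimplePlugin`, `SimpleAccelerator`, and `IPUStylePlugin` are hypothetical stand-ins, not the real classes:

class SimplePlugin:
    """Stand-in for TrainingTypePlugin: stage hooks default to no-ops."""

    def on_train_start(self) -> None:
        pass

    def on_train_end(self) -> None:
        pass


class SimpleAccelerator:
    """Stand-in for Accelerator: forwards each stage hook to its plugin."""

    def __init__(self, plugin: SimplePlugin) -> None:
        self.training_type_plugin = plugin

    def on_train_start(self) -> None:
        return self.training_type_plugin.on_train_start()

    def on_train_end(self) -> None:
        return self.training_type_plugin.on_train_end()


class IPUStylePlugin(SimplePlugin):
    """Overrides one hook, as a device plugin might to set up device state."""

    def on_train_start(self) -> None:
        print("plugin: train starting")


accelerator = SimpleAccelerator(IPUStylePlugin())
accelerator.on_train_start()  # prints "plugin: train starting"
accelerator.on_train_end()    # falls through to the no-op default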

pytorch_lightning/plugins/training_type/training_type_plugin.py

@@ -330,3 +330,41 @@ class TrainingTypePlugin(Plugin, ABC):
    def should_rank_save_checkpoint(self) -> bool:
        """Returns whether the checkpoint should be saved (rank based)."""
        return self.is_global_zero

    def on_train_start(self) -> None:
        """Called when train begins."""
        pass

    def on_validation_start(self) -> None:
        """Called when validation begins."""
        pass

    def on_test_start(self) -> None:
        """Called when test begins."""
        pass

    def on_predict_start(self) -> None:
        """Called when predict begins."""
        pass

    def on_train_end(self) -> None:
        """Called when train ends."""
        pass

    def on_validation_end(self) -> None:
        """Called when validation ends."""
        pass

    def on_test_end(self) -> None:
        """Called when test ends."""
        pass

    def on_predict_end(self) -> None:
        """Called when predict ends."""
        pass

    def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
        """Called in the training loop before anything happens for that batch."""
        pass
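
Since every base implementation here is a no-op, a subclass only overrides the stages it needs. A hedged sketch of a plugin that tracks stage transitions — `StageAwarePlugin` and its `_stage` bookkeeping are illustrative assumptions, not the real IPU plugin from this PR series:

from typing import Any, Optional


class StageAwarePlugin:
    """Illustrative plugin tracking the active trainer stage.

    Hook names mirror TrainingTypePlugin; the _stage attribute and the
    prints are assumed bookkeeping, not library behaviour.
    """

    def __init__(self) -> None:
        self._stage: Optional[str] = None

    def on_train_start(self) -> None:
        self._stage = "train"
        print("entering train stage")

    def on_validation_start(self) -> None:
        self._stage = "validate"

    def on_validation_end(self) -> None:
        self._stage = "train"  # validation during fit returns to the train stage

    def on_train_end(self) -> None:
        self._stage = None
        print("leaving train stage")

    def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
        # Per-batch setup (e.g. device transfer) could happen here.
        print(f"batch {batch_idx} starting during {self._stage}")


plugin = StageAwarePlugin()
plugin.on_train_start()
plugin.on_train_batch_start(batch=None, batch_idx=0, dataloader_idx=0)
plugin.on_train_end()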