diff --git a/docs/source/callbacks.rst b/docs/source/callbacks.rst
index 88ebceedbd..07aba64255 100644
--- a/docs/source/callbacks.rst
+++ b/docs/source/callbacks.rst
@@ -3,11 +3,47 @@
 Callbacks
 =========

+
+Lightning has a callback system to execute arbitrary code. Callbacks should capture NON-ESSENTIAL
+logic that is NOT required for your LightningModule to run.
+
+An overall Lightning system should have:
+
+1. Trainer for all engineering
+2. LightningModule for all research code.
+3. Callbacks for non-essential code.
+
+Example
+
+.. code-block:: python
+
+    import pytorch_lightning as pl
+
+    class MyPrintingCallback(pl.Callback):
+
+        def on_init_start(self, trainer):
+            print('Starting to init trainer!')
+
+        def on_init_end(self, trainer):
+            print('trainer is init now')
+
+        def on_train_end(self, trainer, pl_module):
+            print('do something when training ends')
+
+    # pass to trainer
+    trainer = pl.Trainer(callbacks=[MyPrintingCallback()])
+
+We successfully extended functionality without polluting our super clean LightningModule research code
+
+Callback Class
+--------------
+
 .. automodule:: pytorch_lightning.callbacks
    :noindex:
    :exclude-members:
        _del_model,
        _save_model,
+        _abc_impl,
        on_epoch_end,
        on_train_end,
        on_epoch_start,
diff --git a/docs/source/hooks.rst b/docs/source/hooks.rst
index 56ba52c5d7..eb6d9f6c0d 100644
--- a/docs/source/hooks.rst
+++ b/docs/source/hooks.rst
@@ -1,10 +1,13 @@
+Hooks
+-----
+
 .. automodule:: pytorch_lightning.core.hooks

 Full list of hooks
-------------------
+

 Training set-up
-===============
+================
 - init_ddp_connection
 - init_optimizers
 - configure_apex
diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py
index 17dda597a3..ea06cab756 100644
--- a/pytorch_lightning/callbacks/base.py
+++ b/pytorch_lightning/callbacks/base.py
@@ -11,11 +11,11 @@ import abc
 class Callback(abc.ABC):
     """Abstract base class used to build new callbacks."""

-    def on_init_start(self, trainer, pl_module):
+    def on_init_start(self, trainer):
         """Called when the trainer initialization begins."""
-        assert pl_module is None
+        pass

-    def on_init_end(self, trainer, pl_module):
+    def on_init_end(self, trainer):
         """Called when the trainer initialization ends."""
         pass

diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py
index 66bda345ec..f8c2848b07 100644
--- a/pytorch_lightning/trainer/callback_hook.py
+++ b/pytorch_lightning/trainer/callback_hook.py
@@ -12,15 +12,15 @@ class TrainerCallbackHookMixin(ABC):
         self.callbacks: list[Callback] = []
         self.get_model: Callable = ...

-    def on_init_start(self):
+    def on_init_start(self, trainer):
         """Called when the trainer initialization begins."""
         for callback in self.callbacks:
-            callback.on_init_start(self, None)
+            callback.on_init_start(trainer)

-    def on_init_end(self):
+    def on_init_end(self, trainer):
         """Called when the trainer initialization ends."""
         for callback in self.callbacks:
-            callback.on_init_end(self, self.get_model())
+            callback.on_init_end(trainer)

     def on_fit_start(self):
         """Called when the fit begins."""
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 0a350cb62c..ff1860c0ac 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -618,7 +618,7 @@ class Trainer(TrainerIOMixin,

         # Init callbacks
         self.callbacks = callbacks
-        self.on_init_start()
+        self.on_init_start(self)

         # benchmarking
         self.benchmark = benchmark
@@ -808,7 +808,7 @@ class Trainer(TrainerIOMixin,
         self.init_amp(use_amp)

         # Callback system
-        self.on_init_end()
+        self.on_init_end(self)

     @property
     def slurm_job_id(self) -> int:
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index b2f4a99740..9134378c8f 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -630,10 +630,10 @@ def test_trainer_callback_system(tmpdir):
             self.on_test_start_called = False
             self.on_test_end_called = False

-        def on_init_start(self, trainer, pl_module):
+        def on_init_start(self, trainer):
             self.on_init_start_called = True

-        def on_init_end(self, trainer, pl_module):
+        def on_init_end(self, trainer):
             self.on_init_end_called = True

         def on_fit_start(self, trainer, pl_module):
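
Note on the API change above: on_init_start and on_init_end now receive only the trainer. The old signatures also took a pl_module argument, which TrainerCallbackHookMixin had to fill in with None (for on_init_start) or self.get_model() (for on_init_end) while the Trainer was still running its own __init__. The sketch below is a minimal, framework-free illustration of that dispatch pattern under the new signatures; MiniTrainer and PrintingCallback are hypothetical names used only for this example, not part of pytorch_lightning.

# Minimal, framework-free sketch of the callback dispatch pattern shown in the diff.
# MiniTrainer and PrintingCallback are hypothetical; this is NOT the pytorch_lightning
# implementation, only an illustration of the trainer-only init hook signatures.
import abc


class Callback(abc.ABC):
    """Base class whose init hooks receive only the trainer and default to no-ops."""

    def on_init_start(self, trainer):
        """Called when trainer initialization begins."""
        pass

    def on_init_end(self, trainer):
        """Called when trainer initialization ends."""
        pass


class MiniTrainer:
    """Toy trainer that fans each hook out to every registered callback."""

    def __init__(self, callbacks=None):
        self.callbacks = callbacks or []
        # mirrors `self.on_init_start(self)` near the top of Trainer.__init__
        for callback in self.callbacks:
            callback.on_init_start(self)

        # ... engineering setup would happen here ...

        # mirrors `self.on_init_end(self)` at the end of Trainer.__init__
        for callback in self.callbacks:
            callback.on_init_end(self)


class PrintingCallback(Callback):
    def on_init_start(self, trainer):
        print(f'init started for {type(trainer).__name__}')

    def on_init_end(self, trainer):
        print(f'init finished for {type(trainer).__name__}')


if __name__ == '__main__':
    MiniTrainer(callbacks=[PrintingCallback()])

Hooks later in the lifecycle, such as on_train_end(self, trainer, pl_module) in the docs example or on_fit_start(self, trainer, pl_module) in the test, keep the pl_module argument and are untouched by this diff.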