parent
7beed7cae6
commit
ad80a7d638
|
@ -3,11 +3,47 @@
|
|||
|
||||
Callbacks
|
||||
=========
|
||||
|
||||
Lightning has a callback system to execute arbitrary code. Callbacks should capture NON-ESSENTIAL
|
||||
logic that is NOT required for your LightningModule to run.
|
||||
|
||||
An overall Lightning system should have:
|
||||
|
||||
1. Trainer for all engineering
|
||||
2. LightningModule for all research code.
|
||||
3. Callbacks for non-essential code.
|
||||
|
||||
Example
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import pytorch_lightning as pl
|
||||
|
||||
class MyPrintingCallback(pl.Callback):
|
||||
|
||||
def on_init_start(self, trainer):
|
||||
print('Starting to init trainer!')
|
||||
|
||||
def on_init_end(self, trainer):
|
||||
print('trainer is init now')
|
||||
|
||||
def on_train_end(self, trainer, pl_module):
|
||||
print('do something when training ends')
|
||||
|
||||
# pass to trainer
|
||||
trainer = pl.Trainer(callbacks=[MyPrintingCallback()])
|
||||
|
||||
We successfully extended functionality without polluting our super clean LightningModule research code
|
||||
|
||||
Callback Class
|
||||
--------------
|
||||
|
||||
.. automodule:: pytorch_lightning.callbacks
|
||||
:noindex:
|
||||
:exclude-members:
|
||||
_del_model,
|
||||
_save_model,
|
||||
_abc_impl,
|
||||
on_epoch_end,
|
||||
on_train_end,
|
||||
on_epoch_start,
|
||||
|
|
|
@ -1,10 +1,13 @@
|
|||
Hooks
|
||||
-----
|
||||
|
||||
.. automodule:: pytorch_lightning.core.hooks
|
||||
|
||||
Full list of hooks
|
||||
------------------
|
||||
|
||||
|
||||
Training set-up
|
||||
===============
|
||||
================
|
||||
- init_ddp_connection
|
||||
- init_optimizers
|
||||
- configure_apex
|
||||
|
|
|
@ -11,11 +11,11 @@ import abc
|
|||
class Callback(abc.ABC):
|
||||
"""Abstract base class used to build new callbacks."""
|
||||
|
||||
def on_init_start(self, trainer, pl_module):
|
||||
def on_init_start(self, trainer):
|
||||
"""Called when the trainer initialization begins."""
|
||||
assert pl_module is None
|
||||
pass
|
||||
|
||||
def on_init_end(self, trainer, pl_module):
|
||||
def on_init_end(self, trainer):
|
||||
"""Called when the trainer initialization ends."""
|
||||
pass
|
||||
|
||||
|
|
|
@ -12,15 +12,15 @@ class TrainerCallbackHookMixin(ABC):
|
|||
self.callbacks: list[Callback] = []
|
||||
self.get_model: Callable = ...
|
||||
|
||||
def on_init_start(self):
|
||||
def on_init_start(self, trainer):
|
||||
"""Called when the trainer initialization begins."""
|
||||
for callback in self.callbacks:
|
||||
callback.on_init_start(self, None)
|
||||
callback.on_init_start(trainer)
|
||||
|
||||
def on_init_end(self):
|
||||
def on_init_end(self, trainer):
|
||||
"""Called when the trainer initialization ends."""
|
||||
for callback in self.callbacks:
|
||||
callback.on_init_end(self, self.get_model())
|
||||
callback.on_init_end(trainer)
|
||||
|
||||
def on_fit_start(self):
|
||||
"""Called when the fit begins."""
|
||||
|
|
|
@ -618,7 +618,7 @@ class Trainer(TrainerIOMixin,
|
|||
|
||||
# Init callbacks
|
||||
self.callbacks = callbacks
|
||||
self.on_init_start()
|
||||
self.on_init_start(self)
|
||||
|
||||
# benchmarking
|
||||
self.benchmark = benchmark
|
||||
|
@ -808,7 +808,7 @@ class Trainer(TrainerIOMixin,
|
|||
self.init_amp(use_amp)
|
||||
|
||||
# Callback system
|
||||
self.on_init_end()
|
||||
self.on_init_end(self)
|
||||
|
||||
@property
|
||||
def slurm_job_id(self) -> int:
|
||||
|
|
|
@ -630,10 +630,10 @@ def test_trainer_callback_system(tmpdir):
|
|||
self.on_test_start_called = False
|
||||
self.on_test_end_called = False
|
||||
|
||||
def on_init_start(self, trainer, pl_module):
|
||||
def on_init_start(self, trainer):
|
||||
self.on_init_start_called = True
|
||||
|
||||
def on_init_end(self, trainer, pl_module):
|
||||
def on_init_end(self, trainer):
|
||||
self.on_init_end_called = True
|
||||
|
||||
def on_fit_start(self, trainer, pl_module):
|
||||
|
|
Loading…
Reference in New Issue