clean docs (#967)

* clean docs

* clean docs

* clean docs
This commit is contained in:
William Falcon 2020-02-27 17:21:51 -05:00 committed by GitHub
parent 7beed7cae6
commit ad80a7d638
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 52 additions and 13 deletions

View File

@ -3,11 +3,47 @@
Callbacks
=========
Lightning has a callback system to execute arbitrary code. Callbacks should capture NON-ESSENTIAL
logic that is NOT required for your LightningModule to run.
An overall Lightning system should have:
1. Trainer for all engineering
2. LightningModule for all research code.
3. Callbacks for non-essential code.
Example
.. code-block:: python
import pytorch_lightning as pl
class MyPrintingCallback(pl.Callback):
def on_init_start(self, trainer):
print('Starting to init trainer!')
def on_init_end(self, trainer):
print('trainer is init now')
def on_train_end(self, trainer, pl_module):
print('do something when training ends')
# pass to trainer
trainer = pl.Trainer(callbacks=[MyPrintingCallback()])
We successfully extended functionality without polluting our super clean LightningModule research code
Callback Class
--------------
.. automodule:: pytorch_lightning.callbacks
:noindex:
:exclude-members:
_del_model,
_save_model,
_abc_impl,
on_epoch_end,
on_train_end,
on_epoch_start,

View File

@ -1,10 +1,13 @@
Hooks
-----
.. automodule:: pytorch_lightning.core.hooks
Full list of hooks
------------------
Training set-up
===============
================
- init_ddp_connection
- init_optimizers
- configure_apex

View File

@ -11,11 +11,11 @@ import abc
class Callback(abc.ABC):
"""Abstract base class used to build new callbacks."""
def on_init_start(self, trainer, pl_module):
def on_init_start(self, trainer):
"""Called when the trainer initialization begins."""
assert pl_module is None
pass
def on_init_end(self, trainer, pl_module):
def on_init_end(self, trainer):
"""Called when the trainer initialization ends."""
pass

View File

@ -12,15 +12,15 @@ class TrainerCallbackHookMixin(ABC):
self.callbacks: list[Callback] = []
self.get_model: Callable = ...
def on_init_start(self):
def on_init_start(self, trainer):
"""Called when the trainer initialization begins."""
for callback in self.callbacks:
callback.on_init_start(self, None)
callback.on_init_start(trainer)
def on_init_end(self):
def on_init_end(self, trainer):
"""Called when the trainer initialization ends."""
for callback in self.callbacks:
callback.on_init_end(self, self.get_model())
callback.on_init_end(trainer)
def on_fit_start(self):
"""Called when the fit begins."""

View File

@ -618,7 +618,7 @@ class Trainer(TrainerIOMixin,
# Init callbacks
self.callbacks = callbacks
self.on_init_start()
self.on_init_start(self)
# benchmarking
self.benchmark = benchmark
@ -808,7 +808,7 @@ class Trainer(TrainerIOMixin,
self.init_amp(use_amp)
# Callback system
self.on_init_end()
self.on_init_end(self)
@property
def slurm_job_id(self) -> int:

View File

@ -630,10 +630,10 @@ def test_trainer_callback_system(tmpdir):
self.on_test_start_called = False
self.on_test_end_called = False
def on_init_start(self, trainer, pl_module):
def on_init_start(self, trainer):
self.on_init_start_called = True
def on_init_end(self, trainer, pl_module):
def on_init_end(self, trainer):
self.on_init_end_called = True
def on_fit_start(self, trainer, pl_module):