Improved docs for callbacks (#1370)

* improved docs for callbacks

* class references

* make doctest pass

* doctests

* fix lines too long

* fix line too long

* fix permission error in doctest

* Apply suggestions from code review

Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>

* fix doctest

* fix default

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Adrian Wälchli 2020-04-05 11:38:52 +02:00 committed by GitHub
parent 22bedf9b57
commit 1f2da71069
6 changed files with 88 additions and 72 deletions

View File

@@ -7,7 +7,7 @@ Callbacks
 =========
 Lightning has a callback system to execute arbitrary code. Callbacks should capture NON-ESSENTIAL
-logic that is NOT required for your LightningModule to run.
+logic that is NOT required for your :class:`~pytorch_lightning.core.LightningModule` to run.
 An overall Lightning system should have:
@@ -15,27 +15,29 @@ An overall Lightning system should have:
 2. LightningModule for all research code.
 3. Callbacks for non-essential code.
-Example
+Example:
-.. code-block:: python
+.. doctest::
-    import pytorch_lightning as pl
-
-    class MyPrintingCallback(pl.Callback):
-
-        def on_init_start(self, trainer):
-            print('Starting to init trainer!')
-
-        def on_init_end(self, trainer):
-            print('trainer is init now')
-
-        def on_train_end(self, trainer, pl_module):
-            print('do something when training ends')
-
-    # pass to trainer
-    trainer = pl.Trainer(callbacks=[MyPrintingCallback()])
+    >>> import pytorch_lightning as pl
+    >>> class MyPrintingCallback(pl.Callback):
+    ...
+    ...     def on_init_start(self, trainer):
+    ...         print('Starting to init trainer!')
+    ...
+    ...     def on_init_end(self, trainer):
+    ...         print('trainer is init now')
+    ...
+    ...     def on_train_end(self, trainer, pl_module):
+    ...         print('do something when training ends')
+    ...
+    >>> trainer = pl.Trainer(callbacks=[MyPrintingCallback()])
+    Starting to init trainer!
+    trainer is init now
-We successfully extended functionality without polluting our super clean LightningModule research code
+We successfully extended functionality without polluting our super clean
+:class:`~pytorch_lightning.core.LightningModule` research code.
 ---------
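
Beyond ``on_init_start``/``on_init_end``, callbacks can also hook into the training loop itself. A minimal sketch of a timing callback, assuming the ``(trainer, pl_module)`` hook signatures shown in this commit (the timing logic is illustrative, not part of the library):

.. code-block:: python

    import time

    import pytorch_lightning as pl

    class TimerCallback(pl.Callback):
        """Illustrative callback that reports wall-clock training time."""

        def on_train_start(self, trainer, pl_module):
            # remember when the train loop began
            self._start_time = time.time()

        def on_train_end(self, trainer, pl_module):
            # report the elapsed time once training finishes
            print(f'Training took {time.time() - self._start_time:.1f}s')

    # attach it exactly like MyPrintingCallback above
    trainer = pl.Trainer(callbacks=[TimerCallback()])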

View File

@@ -11,24 +11,23 @@ Enable Early Stopping
 ---------------------
 There are two ways to enable early stopping.
-.. seealso::
-    :class:`~pytorch_lightning.trainer.trainer.Trainer`
-.. code-block:: python
+.. doctest::
+    >>> from pytorch_lightning import Trainer
+    >>> from pytorch_lightning.callbacks import EarlyStopping
     # A) Set early_stop_callback to True. Will look for 'val_loss'
     # in validation_epoch_end() return dict. If it is not found an error is raised.
-    trainer = Trainer(early_stop_callback=True)
+    >>> trainer = Trainer(early_stop_callback=True)
     # B) Or configure your own callback
-    early_stop_callback = EarlyStopping(
-        monitor='val_loss',
-        min_delta=0.00,
-        patience=3,
-        verbose=False,
-        mode='min'
-    )
-    trainer = Trainer(early_stop_callback=early_stop_callback)
+    >>> early_stop_callback = EarlyStopping(
+    ...     monitor='val_loss',
+    ...     min_delta=0.00,
+    ...     patience=3,
+    ...     verbose=False,
+    ...     mode='min'
+    ... )
+    >>> trainer = Trainer(early_stop_callback=early_stop_callback)
In any case, the callback will fall back to the training metrics (returned in
:meth:`~pytorch_lightning.core.lightning.LightningModule.training_step`,
@@ -37,6 +36,8 @@ looking for a key to monitor if validation is disabled or
:meth:`~pytorch_lightning.core.lightning.LightningModule.validation_epoch_end`
is not defined.
+.. seealso::
+    :class:`~pytorch_lightning.trainer.trainer.Trainer`
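
For option A to work, the model has to expose a ``'val_loss'`` key. A minimal sketch of the relevant hooks, assuming the dict-returning interface of this era (the module below is a fragment, not a complete model):

.. code-block:: python

    import torch
    import pytorch_lightning as pl

    class LitModel(pl.LightningModule):
        # only the hooks relevant to early stopping are shown

        def validation_step(self, batch, batch_idx):
            x, y = batch
            loss = torch.nn.functional.cross_entropy(self(x), y)
            return {'val_loss': loss}

        def validation_epoch_end(self, outputs):
            # EarlyStopping reads 'val_loss' from this return dict
            avg_loss = torch.stack([o['val_loss'] for o in outputs]).mean()
            return {'val_loss': avg_loss}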
Disable Early Stopping
----------------------

View File

@@ -1,7 +1,9 @@
 r"""
 Callback Base
 =============
+
 Abstract base class used to build new callbacks.
+
 """
 import abc

View File

@@ -1,6 +1,7 @@
 r"""
 Early Stopping
 ==============
+
 Stop training when a monitored quantity has stopped improving.
 """
@@ -17,31 +18,30 @@ class EarlyStopping(Callback):
     r"""
     Args:
-        monitor (str): quantity to be monitored. Default: ``'val_loss'``.
-        min_delta (float): minimum change in the monitored quantity
+        monitor: quantity to be monitored. Default: ``'val_loss'``.
+        min_delta: minimum change in the monitored quantity
             to qualify as an improvement, i.e. an absolute
             change of less than `min_delta`, will count as no
             improvement. Default: ``0``.
-        patience (int): number of epochs with no improvement
+        patience: number of epochs with no improvement
             after which training will be stopped. Default: ``0``.
-        verbose (bool): verbosity mode. Default: ``False``.
-        mode (str): one of {auto, min, max}. In `min` mode,
+        verbose: verbosity mode. Default: ``False``.
+        mode: one of {auto, min, max}. In `min` mode,
             training will stop when the quantity
             monitored has stopped decreasing; in `max`
             mode it will stop when the quantity
             monitored has stopped increasing; in `auto`
             mode, the direction is automatically inferred
             from the name of the monitored quantity. Default: ``'auto'``.
-        strict (bool): whether to crash the training if `monitor` is
+        strict: whether to crash the training if `monitor` is
             not found in the metrics. Default: ``True``.
     Example::
-        from pytorch_lightning import Trainer
-        from pytorch_lightning.callbacks import EarlyStopping
-        early_stopping = EarlyStopping('val_loss')
-        Trainer(early_stop_callback=early_stopping)
+        >>> from pytorch_lightning import Trainer
+        >>> from pytorch_lightning.callbacks import EarlyStopping
+        >>> early_stopping = EarlyStopping('val_loss')
+        >>> trainer = Trainer(early_stop_callback=early_stopping)
     """
     def __init__(self, monitor: str = 'val_loss', min_delta: float = 0.0, patience: int = 0,
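
To make the ``min_delta``/``patience`` semantics concrete, here is a standalone sketch of the stopping rule in ``min`` mode (illustrative logic, not the class's actual implementation):

.. code-block:: python

    def should_stop(val_losses, min_delta=0.0, patience=3):
        # stop once the monitored loss has not improved by at least
        # min_delta for more than `patience` consecutive epochs
        best = float('inf')
        wait = 0
        for loss in val_losses:
            if loss < best - min_delta:
                best = loss
                wait = 0
            else:
                wait += 1
                if wait > patience:
                    return True
        return False

    print(should_stop([0.9, 0.5, 0.51, 0.52, 0.50, 0.53]))  # True: stalled after epoch 1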

View File

@@ -1,7 +1,9 @@
 r"""
 Gradient Accumulator
 ====================
+
 Change gradient accumulation factor according to scheduling.
+
 """
 import warnings
@@ -22,12 +24,15 @@ class GradientAccumulationScheduler(Callback):
     Example::
-        from pytorch_lightning import Trainer
-        from pytorch_lightning.callbacks import GradientAccumulationScheduler
+        >>> from pytorch_lightning import Trainer
+        >>> from pytorch_lightning.callbacks import GradientAccumulationScheduler
         # at epoch 5 start accumulating every 2 batches
-        accumulator = GradientAccumulationScheduler(scheduling: {5: 2})
-        Trainer(accumulate_grad_batches=accumulator)
+        >>> accumulator = GradientAccumulationScheduler(scheduling={5: 2})
+        >>> trainer = Trainer(callbacks=[accumulator])
+        # alternatively, pass the scheduling dict directly to the Trainer
+        >>> trainer = Trainer(accumulate_grad_batches={5: 2})
     """
     def __init__(self, scheduling: dict):
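
The scheduling dict maps an epoch index to the accumulation factor that takes effect from that epoch on. A sketch with a multi-stage schedule (the epoch and factor values are illustrative):

.. code-block:: python

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import GradientAccumulationScheduler

    # no accumulation until epoch 4, then accumulate 2 batches per step,
    # then 8 batches from epoch 10 on; with a per-step batch size of 32
    # the effective batch size becomes 32, 64 and 256 respectively
    accumulator = GradientAccumulationScheduler(scheduling={4: 2, 10: 8})
    trainer = Trainer(callbacks=[accumulator])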

View File

@@ -3,6 +3,7 @@ Model Checkpointing
 ===================
+
 Automatically save model checkpoints during training.
 """
 import os
@@ -26,18 +27,19 @@ class ModelCheckpoint(Callback):
     Example::
-        # no path
-        ModelCheckpoint()
-        # saves like /my/path/epoch_0.ckpt
+        # custom path
+        # saves a file like: my/path/epoch_0.ckpt
+        >>> checkpoint_callback = ModelCheckpoint('my/path/')
-        # save any arbitrary metrics like and val_loss, etc in name
-        ModelCheckpoint(filepath='/my/path/{epoch}-{val_loss:.2f}-{other_metric:.2f}')
-        # saves file like: /my/path/epoch=2-val_loss=0.2_other_metric=0.3.ckpt
+        # save any arbitrary metrics like `val_loss`, etc. in name
+        # saves a file like: my/path/epoch=2-val_loss=0.2_other_metric=0.3.ckpt
+        >>> checkpoint_callback = ModelCheckpoint(
+        ...     filepath='my/path/{epoch}-{val_loss:.2f}-{other_metric:.2f}'
+        ... )
-        monitor (str): quantity to monitor.
-        verbose (bool): verbosity mode, False or True.
-        save_top_k (int): if `save_top_k == k`,
+        monitor: quantity to monitor.
+        verbose: verbosity mode. Default: ``False``.
+        save_top_k: if `save_top_k == k`,
            the best k models according to
            the quantity monitored will be saved.
            if ``save_top_k == 0``, no models are saved.
@@ -46,7 +48,7 @@ class ModelCheckpoint(Callback):
            if ``save_top_k >= 2`` and the callback is called multiple
            times inside an epoch, the name of the saved file will be
            appended with a version count starting with `v0`.
-        mode (str): one of {auto, min, max}.
+        mode: one of {auto, min, max}.
            If ``save_top_k != 0``, the decision
            to overwrite the current save file is made
            based on either the maximization or the
@@ -54,26 +56,29 @@ class ModelCheckpoint(Callback):
            this should be `max`, for `val_loss` this should
            be `min`, etc. In `auto` mode, the direction is
            automatically inferred from the name of the monitored quantity.
-        save_weights_only (bool): if True, then only the model's weights will be
-            saved (`model.save_weights(filepath)`), else the full model
-            is saved (`model.save(filepath)`).
-        period (int): Interval (number of epochs) between checkpoints.
+        save_weights_only: if ``True``, then only the model's weights will be
+            saved (``model.save_weights(filepath)``), else the full model
+            is saved (``model.save(filepath)``).
+        period: Interval (number of epochs) between checkpoints.
     Example::
-        from pytorch_lightning import Trainer
-        from pytorch_lightning.callbacks import ModelCheckpoint
+        >>> from pytorch_lightning import Trainer
+        >>> from pytorch_lightning.callbacks import ModelCheckpoint
-        # saves checkpoints to my_path whenever 'val_loss' has a new min
-        checkpoint_callback = ModelCheckpoint(filepath='my_path')
-        Trainer(checkpoint_callback=checkpoint_callback)
+        # saves checkpoints to 'my/path/' whenever 'val_loss' has a new min
+        >>> checkpoint_callback = ModelCheckpoint(filepath='my/path/')
+        >>> trainer = Trainer(checkpoint_callback=checkpoint_callback)
         # save epoch and val_loss in name
-        ModelCheckpoint(filepath='/my/path/here/sample-mnist_{epoch:02d}-{val_loss:.2f}')
-        # saves file like: /my/path/here/sample-mnist_epoch=02_val_loss=0.32.ckpt
+        # saves a file like: my/path/sample-mnist_epoch=02_val_loss=0.32.ckpt
+        >>> checkpoint_callback = ModelCheckpoint(
+        ...     filepath='my/path/sample-mnist_{epoch:02d}-{val_loss:.2f}'
+        ... )
     """
-    def __init__(self, filepath, monitor: str = 'val_loss', verbose: bool = False,
+    def __init__(self, filepath: str, monitor: str = 'val_loss', verbose: bool = False,
                  save_top_k: int = 1, save_weights_only: bool = False,
                  mode: str = 'auto', period: int = 1, prefix: str = ''):
         super().__init__()
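
Putting the naming template together with the ``save_top_k``/``mode`` arguments documented above, a usage sketch (the parameter values are illustrative):

.. code-block:: python

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import ModelCheckpoint

    # keep only the 3 checkpoints with the lowest val_loss seen so far,
    # checking at the end of every epoch (period=1)
    checkpoint_callback = ModelCheckpoint(
        filepath='my/path/{epoch}-{val_loss:.2f}',
        monitor='val_loss',
        save_top_k=3,
        mode='min',
        period=1,
    )
    trainer = Trainer(checkpoint_callback=checkpoint_callback)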
@@ -137,9 +142,10 @@ class ModelCheckpoint(Callback):
         return self.monitor_op(current, self.best_k_models[self.kth_best_model])
     def format_checkpoint_name(self, epoch, metrics, ver=None):
-        """Generate a filename according define template.
+        """Generate a filename according to the defined template.
-        Example::
+        Examples:
+            >>> tmpdir = os.path.dirname(__file__)
+            >>> ckpt = ModelCheckpoint(os.path.join(tmpdir, '{epoch}'))
+            >>> os.path.basename(ckpt.format_checkpoint_name(0, {}))
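
The template mechanics can be approximated in plain Python: each ``{name}`` group is rewritten to ``name={name}`` before ``str.format`` is applied. A rough standalone sketch (hypothetical helper, not the method's actual code):

.. code-block:: python

    import re

    def format_name(template, **metrics):
        # rewrite '{epoch}' -> 'epoch={epoch}', keeping any format spec
        groups = re.findall(r'(\{.*?)[:\}]', template)
        for group in groups:
            name = group[1:]
            template = template.replace(group, name + '=' + group)
        return template.format(**metrics)

    print(format_name('{epoch}-{val_loss:.2f}', epoch=2, val_loss=0.2))
    # prints: epoch=2-val_loss=0.20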