Improved docs for callbacks (#1370)

* improved docs for callbacks

* class references

* make doctest pass

* doctests

* fix lines too long

* fix line too long

* fix permission error in doctest

* Apply suggestions from code review

Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>

* fix doctest

* fix default

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Adrian Wälchli 2020-04-05 11:38:52 +02:00, committed by GitHub
parent 22bedf9b57
commit 1f2da71069
6 changed files with 88 additions and 72 deletions


@@ -7,7 +7,7 @@ Callbacks
 =========
 Lightning has a callback system to execute arbitrary code. Callbacks should capture NON-ESSENTIAL
-logic that is NOT required for your LightningModule to run.
+logic that is NOT required for your :class:`~pytorch_lightning.core.LightningModule` to run.

 An overall Lightning system should have:

@@ -15,27 +15,29 @@ An overall Lightning system should have:
 2. LightningModule for all research code.
 3. Callbacks for non-essential code.

-Example
-
-.. code-block:: python
-
-    import pytorch_lightning as pl
-
-    class MyPrintingCallback(pl.Callback):
-
-        def on_init_start(self, trainer):
-            print('Starting to init trainer!')
-
-        def on_init_end(self, trainer):
-            print('trainer is init now')
-
-        def on_train_end(self, trainer, pl_module):
-            print('do something when training ends')
-
-    # pass to trainer
-    trainer = pl.Trainer(callbacks=[MyPrintingCallback()])
-
-We successfully extended functionality without polluting our super clean LightningModule research code
+Example:
+
+.. doctest::
+
+    >>> import pytorch_lightning as pl
+    >>> class MyPrintingCallback(pl.Callback):
+    ...
+    ...     def on_init_start(self, trainer):
+    ...         print('Starting to init trainer!')
+    ...
+    ...     def on_init_end(self, trainer):
+    ...         print('trainer is init now')
+    ...
+    ...     def on_train_end(self, trainer, pl_module):
+    ...         print('do something when training ends')
+    ...
+    >>> trainer = pl.Trainer(callbacks=[MyPrintingCallback()])
+    Starting to init trainer!
+    trainer is init now
+
+We successfully extended functionality without polluting our super clean
+:class:`~pytorch_lightning.core.LightningModule` research code.

 ---------


@@ -11,24 +11,23 @@ Enable Early Stopping
 ---------------------
 There are two ways to enable early stopping.

-.. seealso::
-    :class:`~pytorch_lightning.trainer.trainer.Trainer`
-
-.. code-block:: python
+.. doctest::
+
+    >>> from pytorch_lightning import Trainer
+    >>> from pytorch_lightning.callbacks import EarlyStopping

     # A) Set early_stop_callback to True. Will look for 'val_loss'
     # in validation_epoch_end() return dict. If it is not found an error is raised.
-    trainer = Trainer(early_stop_callback=True)
+    >>> trainer = Trainer(early_stop_callback=True)

     # B) Or configure your own callback
-    early_stop_callback = EarlyStopping(
-        monitor='val_loss',
-        min_delta=0.00,
-        patience=3,
-        verbose=False,
-        mode='min'
-    )
-    trainer = Trainer(early_stop_callback=early_stop_callback)
+    >>> early_stop_callback = EarlyStopping(
+    ...     monitor='val_loss',
+    ...     min_delta=0.00,
+    ...     patience=3,
+    ...     verbose=False,
+    ...     mode='min'
+    ... )
+    >>> trainer = Trainer(early_stop_callback=early_stop_callback)

 In any case, the callback will fall back to the training metrics (returned in
 :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step`,
@@ -37,6 +36,8 @@ looking for a key to monitor if validation is disabled or
 :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_epoch_end`
 is not defined.

+.. seealso::
+    :class:`~pytorch_lightning.trainer.trainer.Trainer`
+
 Disable Early Stopping
 ----------------------


@@ -1,7 +1,9 @@
 r"""
 Callback Base
 =============
+
 Abstract base class used to build new callbacks.
+
 """

 import abc


@@ -1,6 +1,7 @@
 r"""
 Early Stopping
 ==============
+
 Stop training when a monitored quantity has stopped improving.
 """

@@ -17,31 +18,30 @@ class EarlyStopping(Callback):
     r"""

     Args:
-        monitor (str): quantity to be monitored. Default: ``'val_loss'``.
-        min_delta (float): minimum change in the monitored quantity
+        monitor: quantity to be monitored. Default: ``'val_loss'``.
+        min_delta: minimum change in the monitored quantity
             to qualify as an improvement, i.e. an absolute
             change of less than `min_delta`, will count as no
             improvement. Default: ``0``.
-        patience (int): number of epochs with no improvement
+        patience: number of epochs with no improvement
             after which training will be stopped. Default: ``0``.
-        verbose (bool): verbosity mode. Default: ``False``.
-        mode (str): one of {auto, min, max}. In `min` mode,
+        verbose: verbosity mode. Default: ``False``.
+        mode: one of {auto, min, max}. In `min` mode,
             training will stop when the quantity
             monitored has stopped decreasing; in `max`
             mode it will stop when the quantity
             monitored has stopped increasing; in `auto`
             mode, the direction is automatically inferred
             from the name of the monitored quantity. Default: ``'auto'``.
-        strict (bool): whether to crash the training if `monitor` is
+        strict: whether to crash the training if `monitor` is
             not found in the metrics. Default: ``True``.

     Example::

-        from pytorch_lightning import Trainer
-        from pytorch_lightning.callbacks import EarlyStopping
-
-        early_stopping = EarlyStopping('val_loss')
-        Trainer(early_stop_callback=early_stopping)
+        >>> from pytorch_lightning import Trainer
+        >>> from pytorch_lightning.callbacks import EarlyStopping
+        >>> early_stopping = EarlyStopping('val_loss')
+        >>> trainer = Trainer(early_stop_callback=early_stopping)
     """

     def __init__(self, monitor: str = 'val_loss', min_delta: float = 0.0, patience: int = 0,
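The docstring above says that in `auto` mode the direction is inferred from the monitored quantity's name. One plausible heuristic for that inference can be sketched as below; this exact rule is an assumption for illustration, not necessarily the library's implementation.

```python
# Illustrative sketch of 'auto' mode: guess min/max from the metric name.
# Assumption: accuracy-like names are maximized, everything else minimized.
def infer_mode(monitor):
    """Return 'max' for accuracy-like quantities, 'min' otherwise."""
    return 'max' if 'acc' in monitor else 'min'


print(infer_mode('val_loss'))  # min  (losses should decrease)
print(infer_mode('val_acc'))   # max  (accuracy should increase)
```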


@@ -1,7 +1,9 @@
 r"""
 Gradient Accumulator
 ====================
+
 Change gradient accumulation factor according to scheduling.
+
 """

 import warnings
@@ -22,12 +24,15 @@ class GradientAccumulationScheduler(Callback):
     Example::

-        from pytorch_lightning import Trainer
-        from pytorch_lightning.callbacks import GradientAccumulationScheduler
+        >>> from pytorch_lightning import Trainer
+        >>> from pytorch_lightning.callbacks import GradientAccumulationScheduler

         # at epoch 5 start accumulating every 2 batches
-        accumulator = GradientAccumulationScheduler(scheduling: {5: 2})
-        Trainer(accumulate_grad_batches=accumulator)
+        >>> accumulator = GradientAccumulationScheduler(scheduling={5: 2})
+        >>> trainer = Trainer(callbacks=[accumulator])
+
+        # alternatively, pass the scheduling dict directly to the Trainer
+        >>> trainer = Trainer(accumulate_grad_batches={5: 2})
     """

     def __init__(self, scheduling: dict):
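The scheduling dict shown above reads as "from epoch 5 on, accumulate over 2 batches". That lookup can be sketched as a small function; this is an illustrative reading of the dict semantics, not the scheduler class itself.

```python
# Sketch of the scheduling-dict semantics: the most recently passed
# epoch key determines the current accumulation factor. Illustrative only.
def accumulation_factor(scheduling, epoch):
    """Return how many batches to accumulate gradients over at `epoch`."""
    factor = 1  # default: optimizer step on every batch
    for start_epoch in sorted(scheduling):
        if epoch >= start_epoch:
            factor = scheduling[start_epoch]
    return factor


schedule = {5: 2, 10: 4}
print([accumulation_factor(schedule, e) for e in (0, 5, 9, 10)])  # [1, 2, 2, 4]
```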


@@ -3,6 +3,7 @@ Model Checkpointing
 ===================
+
 Automatically save model checkpoints during training.
 """

 import os
@@ -26,18 +27,19 @@ class ModelCheckpoint(Callback):
     Example::

-        # no path
-        ModelCheckpoint()
-        # saves like /my/path/epoch_0.ckpt
-
-        # save any arbitrary metrics like and val_loss, etc in name
-        ModelCheckpoint(filepath='/my/path/{epoch}-{val_loss:.2f}-{other_metric:.2f}')
-        # saves file like: /my/path/epoch=2-val_loss=0.2_other_metric=0.3.ckpt
+        # custom path
+        # saves a file like: my/path/epoch_0.ckpt
+        >>> checkpoint_callback = ModelCheckpoint('my/path/')
+
+        # save any arbitrary metrics like `val_loss`, etc. in name
+        # saves a file like: my/path/epoch=2-val_loss=0.2_other_metric=0.3.ckpt
+        >>> checkpoint_callback = ModelCheckpoint(
+        ...     filepath='my/path/{epoch}-{val_loss:.2f}-{other_metric:.2f}'
+        ... )

-        monitor (str): quantity to monitor.
-        verbose (bool): verbosity mode, False or True.
-        save_top_k (int): if `save_top_k == k`,
+        monitor: quantity to monitor.
+        verbose: verbosity mode. Default: ``False``.
+        save_top_k: if `save_top_k == k`,
             the best k models according to
             the quantity monitored will be saved.
             if ``save_top_k == 0``, no models are saved.
@@ -46,7 +48,7 @@ class ModelCheckpoint(Callback):
         if ``save_top_k >= 2`` and the callback is called multiple
         times inside an epoch, the name of the saved file will be
         appended with a version count starting with `v0`.
-        mode (str): one of {auto, min, max}.
+        mode: one of {auto, min, max}.
             If ``save_top_k != 0``, the decision
             to overwrite the current save file is made
             based on either the maximization or the
@@ -54,26 +56,29 @@ class ModelCheckpoint(Callback):
             this should be `max`, for `val_loss` this should
             be `min`, etc. In `auto` mode, the direction is
             automatically inferred from the name of the monitored quantity.
-        save_weights_only (bool): if True, then only the model's weights will be
-            saved (`model.save_weights(filepath)`), else the full model
-            is saved (`model.save(filepath)`).
-        period (int): Interval (number of epochs) between checkpoints.
+        save_weights_only: if ``True``, then only the model's weights will be
+            saved (``model.save_weights(filepath)``), else the full model
+            is saved (``model.save(filepath)``).
+        period: Interval (number of epochs) between checkpoints.

     Example::

-        from pytorch_lightning import Trainer
-        from pytorch_lightning.callbacks import ModelCheckpoint
-
-        # saves checkpoints to my_path whenever 'val_loss' has a new min
-        checkpoint_callback = ModelCheckpoint(filepath='my_path')
-        Trainer(checkpoint_callback=checkpoint_callback)
-
-        # save epoch and val_loss in name
-        ModelCheckpoint(filepath='/my/path/here/sample-mnist_{epoch:02d}-{val_loss:.2f}')
-        # saves file like: /my/path/here/sample-mnist_epoch=02_val_loss=0.32.ckpt
+        >>> from pytorch_lightning import Trainer
+        >>> from pytorch_lightning.callbacks import ModelCheckpoint
+
+        # saves checkpoints to 'my/path/' whenever 'val_loss' has a new min
+        >>> checkpoint_callback = ModelCheckpoint(filepath='my/path/')
+        >>> trainer = Trainer(checkpoint_callback=checkpoint_callback)
+
+        # save epoch and val_loss in name
+        # saves a file like: my/path/sample-mnist_epoch=02_val_loss=0.32.ckpt
+        >>> checkpoint_callback = ModelCheckpoint(
+        ...     filepath='my/path/sample-mnist_{epoch:02d}-{val_loss:.2f}'
+        ... )
     """

-    def __init__(self, filepath, monitor: str = 'val_loss', verbose: bool = False,
+    def __init__(self, filepath: str, monitor: str = 'val_loss', verbose: bool = False,
                  save_top_k: int = 1, save_weights_only: bool = False,
                  mode: str = 'auto', period: int = 1, prefix: str = ''):
         super().__init__()
@@ -137,9 +142,10 @@ class ModelCheckpoint(Callback):
         return self.monitor_op(current, self.best_k_models[self.kth_best_model])

     def format_checkpoint_name(self, epoch, metrics, ver=None):
-        """Generate a filename according define template.
-
-        Example::
+        """Generate a filename according to the defined template.
+
+        Examples:

         >>> tmpdir = os.path.dirname(__file__)
         >>> ckpt = ModelCheckpoint(os.path.join(tmpdir, '{epoch}'))
         >>> os.path.basename(ckpt.format_checkpoint_name(0, {}))
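The `'{epoch}-{val_loss:.2f}'` templating that the doctests above exercise can be sketched in a few lines: each `{name...}` group becomes `name=value`, formatted with `str.format`. This is an illustrative reimplementation under that assumption, not the library's `format_checkpoint_name` (which handles more edge cases).

```python
import re

# Sketch of checkpoint-name templating: '{epoch}' -> 'epoch=0',
# '{val_loss:.2f}' -> 'val_loss=0.20', then append the '.ckpt' suffix.
# Illustrative only; the real method has additional behavior.
def format_checkpoint_name(template, epoch, metrics):
    values = dict(metrics, epoch=epoch)
    for name, _spec in re.findall(r'\{(\w+)(:.*?)?\}', template):
        # turn '{val_loss:.2f}' into 'val_loss={val_loss:.2f}'
        template = re.sub(r'\{' + name + r'(:.*?)?\}',
                          name + r'={' + name + r'\1}', template)
    return template.format(**values) + '.ckpt'


print(format_checkpoint_name('{epoch}', 0, {}))
# epoch=0.ckpt
print(format_checkpoint_name('{epoch}-{val_loss:.2f}', 2, {'val_loss': 0.2}))
# epoch=2-val_loss=0.20.ckpt
```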