Improved docs for callbacks (#1370)
* improved docs for callbacks
* class references
* make doctest pass
* doctests
* fix lines too long
* fix line too long
* fix permission error in doctest
* Apply suggestions from code review
* fix doctest
* fix default

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
parent 22bedf9b57
commit 1f2da71069
@@ -7,7 +7,7 @@ Callbacks
=========

Lightning has a callback system to execute arbitrary code. Callbacks should capture NON-ESSENTIAL
logic that is NOT required for your :class:`~pytorch_lightning.core.LightningModule` to run.

An overall Lightning system should have:
@@ -15,27 +15,29 @@ An overall Lightning system should have:
2. LightningModule for all research code.
3. Callbacks for non-essential code.

Example:

.. doctest::

    >>> import pytorch_lightning as pl
    >>> class MyPrintingCallback(pl.Callback):
    ...
    ...     def on_init_start(self, trainer):
    ...         print('Starting to init trainer!')
    ...
    ...     def on_init_end(self, trainer):
    ...         print('trainer is init now')
    ...
    ...     def on_train_end(self, trainer, pl_module):
    ...         print('do something when training ends')
    ...
    >>> trainer = pl.Trainer(callbacks=[MyPrintingCallback()])
    Starting to init trainer!
    trainer is init now

We successfully extended functionality without polluting our super clean
:class:`~pytorch_lightning.core.LightningModule` research code.

---------
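To show the same idea with a bit of state, here is a minimal sketch of another non-essential callback; it assumes the base class also provides an ``on_train_start(trainer, pl_module)`` hook analogous to ``on_train_end``.

.. code-block:: python

    import time

    import pytorch_lightning as pl


    class TimerCallback(pl.Callback):
        """Reports wall-clock training time -- bookkeeping, not research logic."""

        def on_train_start(self, trainer, pl_module):
            # remember when training began
            self._start = time.time()

        def on_train_end(self, trainer, pl_module):
            # report elapsed time once training finishes
            print(f'training took {time.time() - self._start:.1f}s')


    trainer = pl.Trainer(callbacks=[TimerCallback()])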
@@ -11,24 +11,23 @@ Enable Early Stopping
---------------------
There are two ways to enable early stopping.

.. doctest::

    >>> from pytorch_lightning import Trainer
    >>> from pytorch_lightning.callbacks import EarlyStopping

    # A) Set early_stop_callback to True. Will look for 'val_loss'
    # in validation_epoch_end() return dict. If it is not found an error is raised.
    >>> trainer = Trainer(early_stop_callback=True)

    # B) Or configure your own callback
    >>> early_stop_callback = EarlyStopping(
    ...     monitor='val_loss',
    ...     min_delta=0.00,
    ...     patience=3,
    ...     verbose=False,
    ...     mode='min'
    ... )
    >>> trainer = Trainer(early_stop_callback=early_stop_callback)

In any case, the callback will fall back to the training metrics (returned in
:meth:`~pytorch_lightning.core.lightning.LightningModule.training_step`,
@@ -37,6 +36,8 @@ looking for a key to monitor if validation is disabled or
:meth:`~pytorch_lightning.core.lightning.LightningModule.validation_epoch_end`
is not defined.

.. seealso::
    :class:`~pytorch_lightning.trainer.trainer.Trainer`
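As a hedged sketch of that fallback (the monitored key ``train_loss`` is illustrative; it assumes the training loop reports a metric under that name):

.. code-block:: python

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import EarlyStopping

    # no validation loop: EarlyStopping watches a training metric instead
    early_stop_callback = EarlyStopping(monitor='train_loss', patience=3, mode='min')
    trainer = Trainer(early_stop_callback=early_stop_callback)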

Disable Early Stopping
----------------------
@@ -1,7 +1,9 @@
r"""
Callback Base
=============

Abstract base class used to build new callbacks.

"""

import abc
@@ -1,6 +1,7 @@
r"""
Early Stopping
==============

Stop training when a monitored quantity has stopped improving.

"""
@@ -17,31 +18,30 @@ class EarlyStopping(Callback):
    r"""

    Args:
        monitor: quantity to be monitored. Default: ``'val_loss'``.
        min_delta: minimum change in the monitored quantity
            to qualify as an improvement, i.e. an absolute
            change of less than `min_delta`, will count as no
            improvement. Default: ``0``.
        patience: number of epochs with no improvement
            after which training will be stopped. Default: ``0``.
        verbose: verbosity mode. Default: ``False``.
        mode: one of {auto, min, max}. In `min` mode,
            training will stop when the quantity
            monitored has stopped decreasing; in `max`
            mode it will stop when the quantity
            monitored has stopped increasing; in `auto`
            mode, the direction is automatically inferred
            from the name of the monitored quantity. Default: ``'auto'``.
        strict: whether to crash the training if `monitor` is
            not found in the metrics. Default: ``True``.

    Example::

        >>> from pytorch_lightning import Trainer
        >>> from pytorch_lightning.callbacks import EarlyStopping
        >>> early_stopping = EarlyStopping('val_loss')
        >>> trainer = Trainer(early_stop_callback=early_stopping)
    """

    def __init__(self, monitor: str = 'val_loss', min_delta: float = 0.0, patience: int = 0,
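The ``strict`` and ``verbose`` arguments are documented above but not exercised in the example; a small illustrative sketch:

.. code-block:: python

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import EarlyStopping

    # tolerate a missing 'val_loss' instead of raising (strict=False), with verbose output
    early_stopping = EarlyStopping(monitor='val_loss', strict=False, verbose=True)
    trainer = Trainer(early_stop_callback=early_stopping)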
@@ -1,7 +1,9 @@
r"""
Gradient Accumulator
====================

Change gradient accumulation factor according to scheduling.

"""

import warnings
@@ -22,12 +24,15 @@ class GradientAccumulationScheduler(Callback):

    Example::

        >>> from pytorch_lightning import Trainer
        >>> from pytorch_lightning.callbacks import GradientAccumulationScheduler

        # at epoch 5 start accumulating every 2 batches
        >>> accumulator = GradientAccumulationScheduler(scheduling={5: 2})
        >>> trainer = Trainer(callbacks=[accumulator])

        # alternatively, pass the scheduling dict directly to the Trainer
        >>> trainer = Trainer(accumulate_grad_batches={5: 2})
    """

    def __init__(self, scheduling: dict):
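A scheduling dict may also hold several switch points; a sketch with illustrative values (the accumulation factor changes at each listed epoch):

.. code-block:: python

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import GradientAccumulationScheduler

    # illustrative schedule: accumulate 2 batches from epoch 5, then 4 from epoch 10
    accumulator = GradientAccumulationScheduler(scheduling={5: 2, 10: 4})
    trainer = Trainer(callbacks=[accumulator])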
@@ -3,6 +3,7 @@ Model Checkpointing
===================

Automatically save model checkpoints during training.

"""

import os
@@ -26,18 +27,19 @@ class ModelCheckpoint(Callback):

            Example::

                # custom path
                # saves a file like: my/path/epoch_0.ckpt
                >>> checkpoint_callback = ModelCheckpoint('my/path/')

                # save any arbitrary metrics like `val_loss`, etc. in name
                # saves a file like: my/path/epoch=2-val_loss=0.2_other_metric=0.3.ckpt
                >>> checkpoint_callback = ModelCheckpoint(
                ...     filepath='my/path/{epoch}-{val_loss:.2f}-{other_metric:.2f}'
                ... )

        monitor: quantity to monitor.
        verbose: verbosity mode. Default: ``False``.
        save_top_k: if `save_top_k == k`,
            the best k models according to
            the quantity monitored will be saved.
            if ``save_top_k == 0``, no models are saved.
@@ -46,7 +48,7 @@ class ModelCheckpoint(Callback):
            if ``save_top_k >= 2`` and the callback is called multiple
            times inside an epoch, the name of the saved file will be
            appended with a version count starting with `v0`.
        mode: one of {auto, min, max}.
            If ``save_top_k != 0``, the decision
            to overwrite the current save file is made
            based on either the maximization or the
@@ -54,26 +56,29 @@ class ModelCheckpoint(Callback):
            this should be `max`, for `val_loss` this should
            be `min`, etc. In `auto` mode, the direction is
            automatically inferred from the name of the monitored quantity.
        save_weights_only: if ``True``, then only the model's weights will be
            saved (``model.save_weights(filepath)``), else the full model
            is saved (``model.save(filepath)``).
        period: Interval (number of epochs) between checkpoints.

    Example::

        >>> from pytorch_lightning import Trainer
        >>> from pytorch_lightning.callbacks import ModelCheckpoint

        # saves checkpoints to 'my/path/' whenever 'val_loss' has a new min
        >>> checkpoint_callback = ModelCheckpoint(filepath='my/path/')
        >>> trainer = Trainer(checkpoint_callback=checkpoint_callback)

        # save epoch and val_loss in name
        # saves a file like: my/path/sample-mnist_epoch=02_val_loss=0.32.ckpt
        >>> checkpoint_callback = ModelCheckpoint(
        ...     filepath='my/path/sample-mnist_{epoch:02d}-{val_loss:.2f}'
        ... )

    """

    def __init__(self, filepath: str, monitor: str = 'val_loss', verbose: bool = False,
                 save_top_k: int = 1, save_weights_only: bool = False,
                 mode: str = 'auto', period: int = 1, prefix: str = ''):
        super().__init__()
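A short illustrative sketch combining ``save_top_k`` and ``mode`` with the arguments documented above:

.. code-block:: python

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import ModelCheckpoint

    # keep only the 3 checkpoints with the lowest 'val_loss'
    checkpoint_callback = ModelCheckpoint(
        filepath='my/path/{epoch}-{val_loss:.2f}',
        monitor='val_loss',
        save_top_k=3,
        mode='min',
    )
    trainer = Trainer(checkpoint_callback=checkpoint_callback)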
@@ -137,9 +142,10 @@ class ModelCheckpoint(Callback):
        return self.monitor_op(current, self.best_k_models[self.kth_best_model])

    def format_checkpoint_name(self, epoch, metrics, ver=None):
        """Generate a filename according to the defined template.

        Examples:

            >>> tmpdir = os.path.dirname(__file__)
            >>> ckpt = ModelCheckpoint(os.path.join(tmpdir, '{epoch}'))
            >>> os.path.basename(ckpt.format_checkpoint_name(0, {}))