Improved docs for callbacks (#1370)
* improved docs for callbacks
* class references
* make doctest pass
* doctests
* fix lines too long
* fix line too long
* fix permission error in doctest
* Apply suggestions from code review

Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>

* fix doctest
* fix default

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
This commit is contained in:
parent 22bedf9b57
commit 1f2da71069
@@ -7,7 +7,7 @@ Callbacks
 =========
 
 Lightning has a callback system to execute arbitrary code. Callbacks should capture NON-ESSENTIAL
-logic that is NOT required for your LightningModule to run.
+logic that is NOT required for your :class:`~pytorch_lightning.core.LightningModule` to run.
 
 An overall Lightning system should have:
 
@@ -15,27 +15,29 @@ An overall Lightning system should have:
 2. LightningModule for all research code.
 3. Callbacks for non-essential code.
 
-Example
+Example:
 
-.. code-block:: python
+.. doctest::
 
-    import pytorch_lightning as pl
-
-    class MyPrintingCallback(pl.Callback):
-
-        def on_init_start(self, trainer):
-            print('Starting to init trainer!')
-
-        def on_init_end(self, trainer):
-            print('trainer is init now')
-
-        def on_train_end(self, trainer, pl_module):
-            print('do something when training ends')
-
-    # pass to trainer
-    trainer = pl.Trainer(callbacks=[MyPrintingCallback()])
+    >>> import pytorch_lightning as pl
+    >>> class MyPrintingCallback(pl.Callback):
+    ...
+    ...     def on_init_start(self, trainer):
+    ...         print('Starting to init trainer!')
+    ...
+    ...     def on_init_end(self, trainer):
+    ...         print('trainer is init now')
+    ...
+    ...     def on_train_end(self, trainer, pl_module):
+    ...         print('do something when training ends')
+    ...
+    >>> trainer = pl.Trainer(callbacks=[MyPrintingCallback()])
+    Starting to init trainer!
+    trainer is init now
 
-We successfully extended functionality without polluting our super clean LightningModule research code
+We successfully extended functionality without polluting our super clean
+:class:`~pytorch_lightning.core.LightningModule` research code.
 
 ---------
 
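A callback is an ordinary class, so it can also keep state between hooks. Below is a small sketch (not part of this commit) that assumes the 0.7-era hook ``on_epoch_end(trainer, pl_module)`` and the ``trainer.callback_metrics`` dict exist as named here:

    import pytorch_lightning as pl

    class LossHistoryCallback(pl.Callback):
        """Hypothetical helper that records a metric after every epoch."""

        def __init__(self):
            self.history = []

        def on_epoch_end(self, trainer, pl_module):
            # assumption: the trainer exposes the latest logged metrics here
            self.history.append(trainer.callback_metrics.get('loss'))

    trainer = pl.Trainer(callbacks=[LossHistoryCallback()])

Because the callback owns its own state, the research code in the LightningModule stays untouched.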
@@ -11,24 +11,23 @@ Enable Early Stopping
 ---------------------
 There are two ways to enable early stopping.
 
-.. seealso::
-    :class:`~pytorch_lightning.trainer.trainer.Trainer`
+.. doctest::
 
-.. code-block:: python
+    >>> from pytorch_lightning import Trainer
+    >>> from pytorch_lightning.callbacks import EarlyStopping
 
     # A) Set early_stop_callback to True. Will look for 'val_loss'
     # in validation_epoch_end() return dict. If it is not found an error is raised.
-    trainer = Trainer(early_stop_callback=True)
-
+    >>> trainer = Trainer(early_stop_callback=True)
     # B) Or configure your own callback
-    early_stop_callback = EarlyStopping(
-        monitor='val_loss',
-        min_delta=0.00,
-        patience=3,
-        verbose=False,
-        mode='min'
-    )
-    trainer = Trainer(early_stop_callback=early_stop_callback)
+    >>> early_stop_callback = EarlyStopping(
+    ...     monitor='val_loss',
+    ...     min_delta=0.00,
+    ...     patience=3,
+    ...     verbose=False,
+    ...     mode='min'
+    ... )
+    >>> trainer = Trainer(early_stop_callback=early_stop_callback)
 
 In any case, the callback will fall back to the training metrics (returned in
 :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step`,
@@ -37,6 +36,8 @@ looking for a key to monitor if validation is disabled or
 :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_epoch_end`
 is not defined.
 
+.. seealso::
+    :class:`~pytorch_lightning.trainer.trainer.Trainer`
 
 Disable Early Stopping
 ----------------------
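The default ``monitor='val_loss'`` only works if that key is actually produced by the model. A minimal sketch of the dict-return API referenced above (the model class and its internals are hypothetical; only the returned keys matter):

    import torch
    import pytorch_lightning as pl
    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import EarlyStopping

    class LitModel(pl.LightningModule):
        # hypothetical model: only the parts early stopping cares about are shown

        def validation_step(self, batch, batch_idx):
            x, y = batch
            return {'val_loss': torch.nn.functional.cross_entropy(self(x), y)}

        def validation_epoch_end(self, outputs):
            # expose the 'val_loss' key that EarlyStopping monitors by default
            avg_loss = torch.stack([out['val_loss'] for out in outputs]).mean()
            return {'val_loss': avg_loss}

    trainer = Trainer(early_stop_callback=EarlyStopping(monitor='val_loss', patience=3))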
@@ -1,7 +1,9 @@
 r"""
 Callback Base
 =============
-Abstract base class used to build new callbacks.
+
+Abstract base class used to build new callbacks.
+
 """
 
 import abc
@@ -1,6 +1,7 @@
 r"""
 Early Stopping
 ==============
+
 Stop training when a monitored quantity has stopped improving.
 
 """
@@ -17,31 +18,30 @@ class EarlyStopping(Callback):
     r"""
 
     Args:
-        monitor (str): quantity to be monitored. Default: ``'val_loss'``.
-        min_delta (float): minimum change in the monitored quantity
+        monitor: quantity to be monitored. Default: ``'val_loss'``.
+        min_delta: minimum change in the monitored quantity
             to qualify as an improvement, i.e. an absolute
            change of less than `min_delta`, will count as no
            improvement. Default: ``0``.
-        patience (int): number of epochs with no improvement
+        patience: number of epochs with no improvement
            after which training will be stopped. Default: ``0``.
-        verbose (bool): verbosity mode. Default: ``False``.
-        mode (str): one of {auto, min, max}. In `min` mode,
+        verbose: verbosity mode. Default: ``False``.
+        mode: one of {auto, min, max}. In `min` mode,
            training will stop when the quantity
            monitored has stopped decreasing; in `max`
            mode it will stop when the quantity
            monitored has stopped increasing; in `auto`
            mode, the direction is automatically inferred
            from the name of the monitored quantity. Default: ``'auto'``.
-        strict (bool): whether to crash the training if `monitor` is
+        strict: whether to crash the training if `monitor` is
            not found in the metrics. Default: ``True``.
 
     Example::
 
-        from pytorch_lightning import Trainer
-        from pytorch_lightning.callbacks import EarlyStopping
-
-        early_stopping = EarlyStopping('val_loss')
-        Trainer(early_stop_callback=early_stopping)
+        >>> from pytorch_lightning import Trainer
+        >>> from pytorch_lightning.callbacks import EarlyStopping
+        >>> early_stopping = EarlyStopping('val_loss')
+        >>> trainer = Trainer(early_stop_callback=early_stopping)
     """
 
     def __init__(self, monitor: str = 'val_loss', min_delta: float = 0.0, patience: int = 0,
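For a metric where larger is better, the documented ``mode`` argument flips the comparison. A short sketch using only the arguments listed above (``val_acc`` is a hypothetical key your validation loop would have to return):

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import EarlyStopping

    # stop once validation accuracy has not improved by at least 0.01 for 5 epochs
    early_stop = EarlyStopping(monitor='val_acc', min_delta=0.01,
                               patience=5, mode='max', verbose=True)
    trainer = Trainer(early_stop_callback=early_stop)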
@@ -1,7 +1,9 @@
 r"""
 Gradient Accumulator
 ====================
+
 Change gradient accumulation factor according to scheduling.
+
 """
 
 import warnings
@@ -22,12 +24,15 @@ class GradientAccumulationScheduler(Callback):
 
     Example::
 
-        from pytorch_lightning import Trainer
-        from pytorch_lightning.callbacks import GradientAccumulationScheduler
+        >>> from pytorch_lightning import Trainer
+        >>> from pytorch_lightning.callbacks import GradientAccumulationScheduler
 
         # at epoch 5 start accumulating every 2 batches
-        accumulator = GradientAccumulationScheduler(scheduling: {5: 2})
-        Trainer(accumulate_grad_batches=accumulator)
+        >>> accumulator = GradientAccumulationScheduler(scheduling={5: 2})
+        >>> trainer = Trainer(callbacks=[accumulator])
+
+        # alternatively, pass the scheduling dict directly to the Trainer
+        >>> trainer = Trainer(accumulate_grad_batches={5: 2})
     """
 
     def __init__(self, scheduling: dict):
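For readers unfamiliar with gradient accumulation, here is a rough plain-PyTorch sketch of what an accumulation factor of 2 means; Lightning performs the equivalent bookkeeping inside its training loop, so this is an illustration rather than the library's code:

    import torch

    model = torch.nn.Linear(10, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    batches = [(torch.randn(4, 10), torch.randn(4, 1)) for _ in range(6)]

    accumulate = 2
    optimizer.zero_grad()
    for i, (x, y) in enumerate(batches):
        loss = torch.nn.functional.mse_loss(model(x), y) / accumulate
        loss.backward()                      # gradients add up across batches
        if (i + 1) % accumulate == 0:
            optimizer.step()                 # one optimizer step per `accumulate` batches
            optimizer.zero_grad()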
@@ -3,6 +3,7 @@ Model Checkpointing
 ===================
+
 Automatically save model checkpoints during training.
 
 """
 
 import os
@@ -26,18 +27,19 @@ class ModelCheckpoint(Callback):
 
             Example::
 
-                # no path
-                ModelCheckpoint()
-                # saves like /my/path/epoch_0.ckpt
+                # custom path
+                # saves a file like: my/path/epoch_0.ckpt
+                >>> checkpoint_callback = ModelCheckpoint('my/path/')
 
-                # save any arbitrary metrics like and val_loss, etc in name
-                ModelCheckpoint(filepath='/my/path/{epoch}-{val_loss:.2f}-{other_metric:.2f}')
-                # saves file like: /my/path/epoch=2-val_loss=0.2_other_metric=0.3.ckpt
-
+                # save any arbitrary metrics like `val_loss`, etc. in name
+                # saves a file like: my/path/epoch=2-val_loss=0.2_other_metric=0.3.ckpt
+                >>> checkpoint_callback = ModelCheckpoint(
+                ...     filepath='my/path/{epoch}-{val_loss:.2f}-{other_metric:.2f}'
+                ... )
 
-        monitor (str): quantity to monitor.
-        verbose (bool): verbosity mode, False or True.
-        save_top_k (int): if `save_top_k == k`,
+        monitor: quantity to monitor.
+        verbose: verbosity mode. Default: ``False``.
+        save_top_k: if `save_top_k == k`,
            the best k models according to
            the quantity monitored will be saved.
            if ``save_top_k == 0``, no models are saved.
@@ -46,7 +48,7 @@ class ModelCheckpoint(Callback):
            if ``save_top_k >= 2`` and the callback is called multiple
            times inside an epoch, the name of the saved file will be
            appended with a version count starting with `v0`.
-        mode (str): one of {auto, min, max}.
+        mode: one of {auto, min, max}.
            If ``save_top_k != 0``, the decision
            to overwrite the current save file is made
            based on either the maximization or the
@@ -54,26 +56,29 @@ class ModelCheckpoint(Callback):
            this should be `max`, for `val_loss` this should
            be `min`, etc. In `auto` mode, the direction is
            automatically inferred from the name of the monitored quantity.
-        save_weights_only (bool): if True, then only the model's weights will be
-            saved (`model.save_weights(filepath)`), else the full model
-            is saved (`model.save(filepath)`).
-        period (int): Interval (number of epochs) between checkpoints.
+        save_weights_only: if ``True``, then only the model's weights will be
+            saved (``model.save_weights(filepath)``), else the full model
+            is saved (``model.save(filepath)``).
+        period: Interval (number of epochs) between checkpoints.
 
     Example::
 
-        from pytorch_lightning import Trainer
-        from pytorch_lightning.callbacks import ModelCheckpoint
+        >>> from pytorch_lightning import Trainer
+        >>> from pytorch_lightning.callbacks import ModelCheckpoint
 
-        # saves checkpoints to my_path whenever 'val_loss' has a new min
-        checkpoint_callback = ModelCheckpoint(filepath='my_path')
-        Trainer(checkpoint_callback=checkpoint_callback)
+        # saves checkpoints to 'my/path/' whenever 'val_loss' has a new min
+        >>> checkpoint_callback = ModelCheckpoint(filepath='my/path/')
+        >>> trainer = Trainer(checkpoint_callback=checkpoint_callback)
 
         # save epoch and val_loss in name
-        ModelCheckpoint(filepath='/my/path/here/sample-mnist_{epoch:02d}-{val_loss:.2f}')
-        # saves file like: /my/path/here/sample-mnist_epoch=02_val_loss=0.32.ckpt
+        # saves a file like: my/path/sample-mnist_epoch=02_val_loss=0.32.ckpt
+        >>> checkpoint_callback = ModelCheckpoint(
+        ...     filepath='my/path/sample-mnist_{epoch:02d}-{val_loss:.2f}'
+        ... )
+
     """
 
-    def __init__(self, filepath, monitor: str = 'val_loss', verbose: bool = False,
+    def __init__(self, filepath: str, monitor: str = 'val_loss', verbose: bool = False,
                  save_top_k: int = 1, save_weights_only: bool = False,
                  mode: str = 'auto', period: int = 1, prefix: str = ''):
         super().__init__()
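The arguments documented above combine naturally; a short sketch that keeps only the best checkpoints, using parameter names taken from the docstring (the path is purely illustrative):

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import ModelCheckpoint

    # keep the 3 checkpoints with the lowest 'val_loss' seen so far
    checkpoint_callback = ModelCheckpoint(
        filepath='my/path/{epoch}-{val_loss:.2f}',
        monitor='val_loss',
        save_top_k=3,
        mode='min',
    )
    trainer = Trainer(checkpoint_callback=checkpoint_callback)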
@@ -137,9 +142,10 @@ class ModelCheckpoint(Callback):
         return self.monitor_op(current, self.best_k_models[self.kth_best_model])
 
     def format_checkpoint_name(self, epoch, metrics, ver=None):
-        """Generate a filename according define template.
+        """Generate a filename according to the defined template.
 
-        Example::
-
+        Examples:
+            >>> tmpdir = os.path.dirname(__file__)
+            >>> ckpt = ModelCheckpoint(os.path.join(tmpdir, '{epoch}'))
+            >>> os.path.basename(ckpt.format_checkpoint_name(0, {}))
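For intuition only, a very rough sketch of the template idea behind ``format_checkpoint_name``: each ``{name}`` group is rewritten to ``name={name}`` and then filled with ``str.format``. This is not the method's actual implementation, and the real filenames may use slightly different separators:

    import re

    def sketch_format_name(template, epoch, metrics):
        # e.g. '{epoch:02d}' -> 'epoch={epoch:02d}', then fill in the values
        for group in re.findall(r'(\{.*?)[:\}]', template):
            name = group[1:]
            template = template.replace(group, name + '=' + group)
        return template.format(epoch=epoch, **metrics) + '.ckpt'

    print(sketch_format_name('sample-mnist_{epoch:02d}-{val_loss:.2f}', 2, {'val_loss': 0.3217}))
    # -> sample-mnist_epoch=02-val_loss=0.32.ckpt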