From 1f2da710697bd4b090dc3b74bf6a583f3ea3d913 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 5 Apr 2020 11:38:52 +0200 Subject: [PATCH] Improved docs for callbacks (#1370) * improved docs for callbacks * class references * make doctest pass * doctests * fix lines too long * fix line too long * fix permission error in doctest * Apply suggestions from code review Co-Authored-By: Jirka Borovec * fix doctest * fix default Co-authored-by: Jirka Borovec --- docs/source/callbacks.rst | 38 +++++++------ docs/source/early_stopping.rst | 27 ++++----- pytorch_lightning/callbacks/base.py | 4 +- pytorch_lightning/callbacks/early_stopping.py | 22 ++++---- .../gradient_accumulation_scheduler.py | 13 +++-- .../callbacks/model_checkpoint.py | 56 ++++++++++--------- 6 files changed, 88 insertions(+), 72 deletions(-) diff --git a/docs/source/callbacks.rst b/docs/source/callbacks.rst index 364bf07213..ffb7671b72 100644 --- a/docs/source/callbacks.rst +++ b/docs/source/callbacks.rst @@ -7,7 +7,7 @@ Callbacks ========= Lightning has a callback system to execute arbitrary code. Callbacks should capture NON-ESSENTIAL -logic that is NOT required for your LightningModule to run. +logic that is NOT required for your :class:`~pytorch_lightning.core.LightningModule` to run. An overall Lightning system should have: @@ -15,27 +15,29 @@ An overall Lightning system should have: 2. LightningModule for all research code. 3. Callbacks for non-essential code. -Example -.. code-block:: python +Example: - import pytorch_lightning as pl +.. doctest:: - class MyPrintingCallback(pl.Callback): + >>> import pytorch_lightning as pl + >>> class MyPrintingCallback(pl.Callback): + ... + ... def on_init_start(self, trainer): + ... print('Starting to init trainer!') + ... + ... def on_init_end(self, trainer): + ... print('trainer is init now') + ... + ... def on_train_end(self, trainer, pl_module): + ... print('do something when training ends') + ... + >>> trainer = pl.Trainer(callbacks=[MyPrintingCallback()]) + Starting to init trainer! + trainer is init now - def on_init_start(self, trainer): - print('Starting to init trainer!') - - def on_init_end(self, trainer): - print('trainer is init now') - - def on_train_end(self, trainer, pl_module): - print('do something when training ends') - - # pass to trainer - trainer = pl.Trainer(callbacks=[MyPrintingCallback()]) - -We successfully extended functionality without polluting our super clean LightningModule research code +We successfully extended functionality without polluting our super clean +:class:`~pytorch_lightning.core.LightningModule` research code. --------- diff --git a/docs/source/early_stopping.rst b/docs/source/early_stopping.rst index 585627a3b0..e94cb079a8 100644 --- a/docs/source/early_stopping.rst +++ b/docs/source/early_stopping.rst @@ -11,24 +11,23 @@ Enable Early Stopping --------------------- There are two ways to enable early stopping. -.. seealso:: - :class:`~pytorch_lightning.trainer.trainer.Trainer` +.. doctest:: -.. code-block:: python + >>> from pytorch_lightning import Trainer + >>> from pytorch_lightning.callbacks import EarlyStopping # A) Set early_stop_callback to True. Will look for 'val_loss' # in validation_epoch_end() return dict. If it is not found an error is raised. - trainer = Trainer(early_stop_callback=True) - + >>> trainer = Trainer(early_stop_callback=True) # B) Or configure your own callback - early_stop_callback = EarlyStopping( - monitor='val_loss', - min_delta=0.00, - patience=3, - verbose=False, - mode='min' - ) - trainer = Trainer(early_stop_callback=early_stop_callback) + >>> early_stop_callback = EarlyStopping( + ... monitor='val_loss', + ... min_delta=0.00, + ... patience=3, + ... verbose=False, + ... mode='min' + ... ) + >>> trainer = Trainer(early_stop_callback=early_stop_callback) In any case, the callback will fall back to the training metrics (returned in :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step`, @@ -37,6 +36,8 @@ looking for a key to monitor if validation is disabled or :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_epoch_end` is not defined. +.. seealso:: + :class:`~pytorch_lightning.trainer.trainer.Trainer` Disable Early Stopping ---------------------- diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index c7898e9dc1..9bf576b0c1 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -1,7 +1,9 @@ r""" Callback Base ============= - Abstract base class used to build new callbacks. + +Abstract base class used to build new callbacks. + """ import abc diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index 53d6b4cfb1..f477cd724b 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -1,6 +1,7 @@ r""" Early Stopping ============== + Stop training when a monitored quantity has stopped improving. """ @@ -17,31 +18,30 @@ class EarlyStopping(Callback): r""" Args: - monitor (str): quantity to be monitored. Default: ``'val_loss'``. - min_delta (float): minimum change in the monitored quantity + monitor: quantity to be monitored. Default: ``'val_loss'``. + min_delta: minimum change in the monitored quantity to qualify as an improvement, i.e. an absolute change of less than `min_delta`, will count as no improvement. Default: ``0``. - patience (int): number of epochs with no improvement + patience: number of epochs with no improvement after which training will be stopped. Default: ``0``. - verbose (bool): verbosity mode. Default: ``False``. - mode (str): one of {auto, min, max}. In `min` mode, + verbose: verbosity mode. Default: ``False``. + mode: one of {auto, min, max}. In `min` mode, training will stop when the quantity monitored has stopped decreasing; in `max` mode it will stop when the quantity monitored has stopped increasing; in `auto` mode, the direction is automatically inferred from the name of the monitored quantity. Default: ``'auto'``. - strict (bool): whether to crash the training if `monitor` is + strict: whether to crash the training if `monitor` is not found in the metrics. Default: ``True``. Example:: - from pytorch_lightning import Trainer - from pytorch_lightning.callbacks import EarlyStopping - - early_stopping = EarlyStopping('val_loss') - Trainer(early_stop_callback=early_stopping) + >>> from pytorch_lightning import Trainer + >>> from pytorch_lightning.callbacks import EarlyStopping + >>> early_stopping = EarlyStopping('val_loss') + >>> trainer = Trainer(early_stop_callback=early_stopping) """ def __init__(self, monitor: str = 'val_loss', min_delta: float = 0.0, patience: int = 0, diff --git a/pytorch_lightning/callbacks/gradient_accumulation_scheduler.py b/pytorch_lightning/callbacks/gradient_accumulation_scheduler.py index 29565800d4..b0563f46c8 100644 --- a/pytorch_lightning/callbacks/gradient_accumulation_scheduler.py +++ b/pytorch_lightning/callbacks/gradient_accumulation_scheduler.py @@ -1,7 +1,9 @@ r""" Gradient Accumulator ==================== + Change gradient accumulation factor according to scheduling. + """ import warnings @@ -22,12 +24,15 @@ class GradientAccumulationScheduler(Callback): Example:: - from pytorch_lightning import Trainer - from pytorch_lightning.callbacks import GradientAccumulationScheduler + >>> from pytorch_lightning import Trainer + >>> from pytorch_lightning.callbacks import GradientAccumulationScheduler # at epoch 5 start accumulating every 2 batches - accumulator = GradientAccumulationScheduler(scheduling: {5: 2}) - Trainer(accumulate_grad_batches=accumulator) + >>> accumulator = GradientAccumulationScheduler(scheduling={5: 2}) + >>> trainer = Trainer(callbacks=[accumulator]) + + # alternatively, pass the scheduling dict directly to the Trainer + >>> trainer = Trainer(accumulate_grad_batches={5: 2}) """ def __init__(self, scheduling: dict): diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 54997e9e63..90f649394f 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -3,6 +3,7 @@ Model Checkpointing =================== Automatically save model checkpoints during training. + """ import os @@ -26,18 +27,19 @@ class ModelCheckpoint(Callback): Example:: - # no path - ModelCheckpoint() - # saves like /my/path/epoch_0.ckpt + # custom path + # saves a file like: my/path/epoch_0.ckpt + >>> checkpoint_callback = ModelCheckpoint('my/path/') - # save any arbitrary metrics like and val_loss, etc in name - ModelCheckpoint(filepath='/my/path/{epoch}-{val_loss:.2f}-{other_metric:.2f}') - # saves file like: /my/path/epoch=2-val_loss=0.2_other_metric=0.3.ckpt + # save any arbitrary metrics like `val_loss`, etc. in name + # saves a file like: my/path/epoch=2-val_loss=0.2_other_metric=0.3.ckpt + >>> checkpoint_callback = ModelCheckpoint( + ... filepath='my/path/{epoch}-{val_loss:.2f}-{other_metric:.2f}' + ... ) - - monitor (str): quantity to monitor. - verbose (bool): verbosity mode, False or True. - save_top_k (int): if `save_top_k == k`, + monitor: quantity to monitor. + verbose: verbosity mode. Default: ``False``. + save_top_k: if `save_top_k == k`, the best k models according to the quantity monitored will be saved. if ``save_top_k == 0``, no models are saved. @@ -46,7 +48,7 @@ class ModelCheckpoint(Callback): if ``save_top_k >= 2`` and the callback is called multiple times inside an epoch, the name of the saved file will be appended with a version count starting with `v0`. - mode (str): one of {auto, min, max}. + mode: one of {auto, min, max}. If ``save_top_k != 0``, the decision to overwrite the current save file is made based on either the maximization or the @@ -54,26 +56,29 @@ class ModelCheckpoint(Callback): this should be `max`, for `val_loss` this should be `min`, etc. In `auto` mode, the direction is automatically inferred from the name of the monitored quantity. - save_weights_only (bool): if True, then only the model's weights will be - saved (`model.save_weights(filepath)`), else the full model - is saved (`model.save(filepath)`). - period (int): Interval (number of epochs) between checkpoints. + save_weights_only: if ``True``, then only the model's weights will be + saved (``model.save_weights(filepath)``), else the full model + is saved (``model.save(filepath)``). + period: Interval (number of epochs) between checkpoints. Example:: - from pytorch_lightning import Trainer - from pytorch_lightning.callbacks import ModelCheckpoint + >>> from pytorch_lightning import Trainer + >>> from pytorch_lightning.callbacks import ModelCheckpoint - # saves checkpoints to my_path whenever 'val_loss' has a new min - checkpoint_callback = ModelCheckpoint(filepath='my_path') - Trainer(checkpoint_callback=checkpoint_callback) + # saves checkpoints to 'my/path/' whenever 'val_loss' has a new min + >>> checkpoint_callback = ModelCheckpoint(filepath='my/path/') + >>> trainer = Trainer(checkpoint_callback=checkpoint_callback) # save epoch and val_loss in name - ModelCheckpoint(filepath='/my/path/here/sample-mnist_{epoch:02d}-{val_loss:.2f}') - # saves file like: /my/path/here/sample-mnist_epoch=02_val_loss=0.32.ckpt + # saves a file like: my/path/sample-mnist_epoch=02_val_loss=0.32.ckpt + >>> checkpoint_callback = ModelCheckpoint( + ... filepath='my/path/sample-mnist_{epoch:02d}-{val_loss:.2f}' + ... ) + """ - def __init__(self, filepath, monitor: str = 'val_loss', verbose: bool = False, + def __init__(self, filepath: str, monitor: str = 'val_loss', verbose: bool = False, save_top_k: int = 1, save_weights_only: bool = False, mode: str = 'auto', period: int = 1, prefix: str = ''): super().__init__() @@ -137,9 +142,10 @@ class ModelCheckpoint(Callback): return self.monitor_op(current, self.best_k_models[self.kth_best_model]) def format_checkpoint_name(self, epoch, metrics, ver=None): - """Generate a filename according define template. + """Generate a filename according to the defined template. + + Example:: - Examples: >>> tmpdir = os.path.dirname(__file__) >>> ckpt = ModelCheckpoint(os.path.join(tmpdir, '{epoch}')) >>> os.path.basename(ckpt.format_checkpoint_name(0, {}))