diff --git a/docs/source/callbacks.rst b/docs/source/callbacks.rst index 57f7b8a9a5..cd81608149 100644 --- a/docs/source/callbacks.rst +++ b/docs/source/callbacks.rst @@ -20,6 +20,7 @@ An overall Lightning system should have: 2. LightningModule for all research code. 3. Callbacks for non-essential code. +| Example: @@ -46,20 +47,7 @@ Example: We successfully extended functionality without polluting our super clean :class:`~pytorch_lightning.core.LightningModule` research code. ----------------- - -Best Practices -------------- -The following are best practices when using/designing callbacks. - -1. Callbacks should be isolated in their functionality. -2. Your callback should not rely on the behavior of other callbacks in order to work properly. -3. Do not manually call methods from the callback. -4. Directly calling methods (eg. `on_validation_end`) is strongly discouraged. -5. Whenever possible, your callbacks should not depend on the order in which they are executed. - - ---------- .. automodule:: pytorch_lightning.callbacks.base :noindex: @@ -112,3 +100,16 @@ The following are best practices when using/designing callbacks. .. automodule:: pytorch_lightning.callbacks.progress :noindex: :exclude-members: + + +---------------- + +Best Practices +-------------- +The following are best practices when using/designing callbacks. + +1. Callbacks should be isolated in their functionality. +2. Your callback should not rely on the behavior of other callbacks in order to work properly. +3. Do not manually call methods from the callback. +4. Directly calling methods (eg. `on_validation_end`) is strongly discouraged. +5. Whenever possible, your callbacks should not depend on the order in which they are executed. diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index 7c1d055477..a9c6e1fb52 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -1,6 +1,6 @@ r""" Callback Base -============= +------------- Abstract base class used to build new callbacks. diff --git a/pytorch_lightning/core/__init__.py b/pytorch_lightning/core/__init__.py index cbb9684487..c05f27e8d3 100644 --- a/pytorch_lightning/core/__init__.py +++ b/pytorch_lightning/core/__init__.py @@ -212,6 +212,63 @@ don't run your test data by accident. Instead you have to explicitly call: trainer = Trainer() trainer.test(model, test_dataloaders=test_dataloader) +------------- + +TrainResult +^^^^^^^^^^^ +When you are using the `_step_end` and `_epoch_end` only for aggregating metrics and then logging, +consider using either a `EvalResult` or `TrainResult` instead. + +Here's a training loop structure + +.. code-block:: python + + def training_step(self, batch, batch_idx): + return {'loss': loss} + + def training_epoch_end(self, training_step_outputs): + epoch_loss = torch.stack([x['loss'] for x in training_step_outputs]).mean() + return { + 'log': {'epoch_loss': epoch_loss}, + 'progress_bar': {'epoch_loss': epoch_loss} + } + +using the equivalent syntax via the `TrainResult` object: + +.. code-block:: python + + def training_step(self, batch_subset, batch_idx): + loss = ... + result = pl.TrainResult(minimize=loss) + result.log('train_loss', loss, prog_bar=True) + return result + +EvalResult +^^^^^^^^^^ +Same for val/test loop + +.. code-block:: python + + def validation_step(self, batch, batch_idx): + return {'some_metric': some_metric} + + def validation_epoch_end(self, validation_step_outputs): + some_metric_mean = torch.stack([x['some_metric'] for x in validation_step_outputs]).mean() + return { + 'log': {'some_metric_mean': some_metric_mean}, + 'progress_bar': {'some_metric_mean': some_metric_mean} + } + +With the equivalent using the `EvalResult` syntax + +.. code-block:: python + + def validation_step(self, batch, batch_idx): + some_metric = ... + result = pl.EvalResult(checkpoint_on=some_metric) + result.log('some_metric', some_metric, prog_bar=True) + return result + ---------- Training_step_end method