Update extensions doc (#10778)

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Rohit Gupta 2021-11-30 02:08:23 +05:30 committed by GitHub
parent 8bf7f9cce7
commit ce95891f6a
10 changed files with 152 additions and 79 deletions

View File

@@ -168,7 +168,7 @@ To load a model along with its weights, biases and hyperparameters use the follo
.. code-block:: python
model = MyLightingModule.load_from_checkpoint(PATH)
model = MyLightningModule.load_from_checkpoint(PATH)
print(model.learning_rate)
# prints the learning_rate you used in this checkpoint

View File

@@ -3,7 +3,7 @@
############
Accelerators
############
Accelerators connect a Lightning Trainer to arbitrary accelerators (CPUs, GPUs, TPUs, etc). Accelerators
Accelerators connect a Lightning Trainer to arbitrary accelerators (CPUs, GPUs, TPUs, IPUs). Accelerators
also manage distributed communication through :ref:`Plugins` (like DP, DDP, HPC cluster) and
can be configured to run on arbitrary clusters or to link up to arbitrary
computational strategies like 16-bit precision via AMP and Apex.
@@ -26,7 +26,7 @@ One to handle differences from the training routine and one to handle different
from pytorch_lightning.plugins import NativeMixedPrecisionPlugin, DDPPlugin
accelerator = GPUAccelerator(
precision_plugin=NativeMixedPrecisionPlugin(16, "cuda"),
precision_plugin=NativeMixedPrecisionPlugin(precision=16, device="cuda"),
training_type_plugin=DDPPlugin(),
)
trainer = Trainer(accelerator=accelerator)
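You rarely need to assemble an accelerator by hand; the same setup can usually be requested through Trainer flags. A minimal sketch, assuming the ``gpus``, ``precision``, and ``strategy`` arguments available around this release:

.. code-block:: python

    from pytorch_lightning import Trainer

    # roughly equivalent setup selected via flags (assumed available in this release):
    # native 16-bit AMP on one GPU with DDP communication
    trainer = Trainer(gpus=1, precision=16, strategy="ddp")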

View File

@@ -21,7 +21,7 @@ Callback
A callback is a self-contained program that can be reused across projects.
Lightning has a callback system to execute callbacks when needed. Callbacks should capture NON-ESSENTIAL
Lightning has a callback system to execute them when needed. Callbacks should capture NON-ESSENTIAL
logic that is NOT required for your :doc:`lightning module <../common/lightning_module>` to run.
Here's the flow of how the callback hooks are executed:
@@ -47,10 +47,10 @@ Example:
class MyPrintingCallback(Callback):
def on_init_start(self, trainer):
print("Starting to init trainer!")
print("Starting to initialize the trainer!")
def on_init_end(self, trainer):
print("trainer is init now")
print("trainer is initialized now")
def on_train_end(self, trainer, pl_module):
print("do something when training ends")
@@ -60,8 +60,8 @@ Example:
.. testoutput::
Starting to init trainer!
trainer is init now
Starting to initialize the trainer!
trainer is initialized now
We successfully extended functionality without polluting our super clean
:doc:`lightning module <../common/lightning_module>` research code.
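To put the callback to work, pass an instance to the Trainer; a minimal usage sketch:

.. code-block:: python

    from pytorch_lightning import Trainer

    # Lightning invokes the callback's hooks automatically at the right points
    trainer = Trainer(callbacks=[MyPrintingCallback()])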
@@ -86,7 +86,7 @@ Lightning has a few built-in callbacks.
.. note::
For a richer collection of callbacks, check out our
`bolts library <https://lightning-bolts.readthedocs.io/en/latest/callbacks.html>`_.
`bolts library <https://lightning-bolts.readthedocs.io/en/latest/index.html>`_.
.. currentmodule:: pytorch_lightning.callbacks
@@ -108,12 +108,13 @@ Lightning has a few built-in callbacks.
ModelCheckpoint
ModelPruning
ModelSummary
ProgressBar
ProgressBarBase
QuantizationAwareTraining
RichModelSummary
RichProgressBar
QuantizationAwareTraining
StochasticWeightAveraging
Timer
TQDMProgressBar
XLAStatsMonitor
----------
@@ -197,245 +198,310 @@ The following are best practices when using/designing callbacks.
.. _hooks:
Available Callback hooks
------------------------
Callback API
------------
Here is the full API of methods available in the Callback base class.
The :class:`~pytorch_lightning.callbacks.Callback` class is the base for all callbacks in Lightning, just like the :class:`~pytorch_lightning.core.lightning.LightningModule` is the base for all models.
It defines a public interface that each callback implementation must follow; the key ones are:
Properties
^^^^^^^^^^
state_key
~~~~~~~~~
.. autoattribute:: pytorch_lightning.callbacks.Callback.state_key
:noindex:
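For illustration, here is a hedged sketch of a stateful callback whose ``state_key`` disambiguates multiple instances in a checkpoint; the ``Counter`` class is hypothetical, and the sketch assumes the ``_generate_state_key`` helper on the base class:

.. code-block:: python

    from pytorch_lightning.callbacks import Callback


    class Counter(Callback):
        def __init__(self, what="epochs"):
            self.what = what
            self.count = 0

        @property
        def state_key(self):
            # include the constructor argument so two Counter instances
            # get distinct entries in the checkpoint
            return self._generate_state_key(what=self.what)

        def on_train_epoch_end(self, trainer, pl_module):
            if self.what == "epochs":
                self.count += 1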
Hooks
^^^^^
on_configure_sharded_model
~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automethod:: pytorch_lightning.callbacks.Callback.on_configure_sharded_model
:noindex:
on_before_accelerator_backend_setup
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automethod:: pytorch_lightning.callbacks.Callback.on_before_accelerator_backend_setup
:noindex:
setup
^^^^^
~~~~~
.. automethod:: pytorch_lightning.callbacks.Callback.setup
:noindex:
teardown
^^^^^^^^
~~~~~~~~
.. automethod:: pytorch_lightning.callbacks.Callback.teardown
:noindex:
on_init_start
^^^^^^^^^^^^^^
~~~~~~~~~~~~~
.. automethod:: pytorch_lightning.callbacks.Callback.on_init_start
:noindex:
on_init_end
^^^^^^^^^^^
~~~~~~~~~~~
.. automethod:: pytorch_lightning.callbacks.Callback.on_init_end
:noindex:
on_fit_start
^^^^^^^^^^^^
~~~~~~~~~~~~
.. automethod:: pytorch_lightning.callbacks.Callback.on_fit_start
:noindex:
on_fit_end
^^^^^^^^^^
~~~~~~~~~~
.. automethod:: pytorch_lightning.callbacks.Callback.on_fit_end
:noindex:
on_sanity_check_start
^^^^^^^^^^^^^^^^^^^^^
~~~~~~~~~~~~~~~~~~~~~
.. automethod:: pytorch_lightning.callbacks.Callback.on_sanity_check_start
:noindex:
on_sanity_check_end
^^^^^^^^^^^^^^^^^^^
~~~~~~~~~~~~~~~~~~~
.. automethod:: pytorch_lightning.callbacks.Callback.on_sanity_check_end
:noindex:
on_train_batch_start
^^^^^^^^^^^^^^^^^^^^
~~~~~~~~~~~~~~~~~~~~
.. automethod:: pytorch_lightning.callbacks.Callback.on_train_batch_start
:noindex:
on_train_batch_end
^^^^^^^^^^^^^^^^^^
~~~~~~~~~~~~~~~~~~
.. automethod:: pytorch_lightning.callbacks.Callback.on_train_batch_end
:noindex:
on_train_epoch_start
^^^^^^^^^^^^^^^^^^^^
~~~~~~~~~~~~~~~~~~~~
.. automethod:: pytorch_lightning.callbacks.Callback.on_train_epoch_start
:noindex:
on_train_epoch_end
^^^^^^^^^^^^^^^^^^
~~~~~~~~~~~~~~~~~~
.. automethod:: pytorch_lightning.callbacks.Callback.on_train_epoch_end
:noindex:
on_validation_epoch_start
^^^^^^^^^^^^^^^^^^^^^^^^^
~~~~~~~~~~~~~~~~~~~~~~~~~
.. automethod:: pytorch_lightning.callbacks.Callback.on_validation_epoch_start
:noindex:
on_validation_epoch_end
^^^^^^^^^^^^^^^^^^^^^^^
~~~~~~~~~~~~~~~~~~~~~~~
.. automethod:: pytorch_lightning.callbacks.Callback.on_validation_epoch_end
:noindex:
on_test_epoch_start
^^^^^^^^^^^^^^^^^^^
~~~~~~~~~~~~~~~~~~~
.. automethod:: pytorch_lightning.callbacks.Callback.on_test_epoch_start
:noindex:
on_test_epoch_end
^^^^^^^^^^^^^^^^^
~~~~~~~~~~~~~~~~~
.. automethod:: pytorch_lightning.callbacks.Callback.on_test_epoch_end
:noindex:
on_predict_epoch_start
~~~~~~~~~~~~~~~~~~~~~~
.. automethod:: pytorch_lightning.callbacks.Callback.on_predict_epoch_start
:noindex:
on_predict_epoch_end
~~~~~~~~~~~~~~~~~~~~
.. automethod:: pytorch_lightning.callbacks.Callback.on_predict_epoch_end
:noindex:
on_epoch_start
^^^^^^^^^^^^^^
~~~~~~~~~~~~~~
.. automethod:: pytorch_lightning.callbacks.Callback.on_epoch_start
:noindex:
on_epoch_end
^^^^^^^^^^^^
~~~~~~~~~~~~
.. automethod:: pytorch_lightning.callbacks.Callback.on_epoch_end
:noindex:
on_batch_start
^^^^^^^^^^^^^^
~~~~~~~~~~~~~~
.. automethod:: pytorch_lightning.callbacks.Callback.on_batch_start
:noindex:
on_batch_end
~~~~~~~~~~~~
.. automethod:: pytorch_lightning.callbacks.Callback.on_batch_end
:noindex:
on_validation_batch_start
^^^^^^^^^^^^^^^^^^^^^^^^^
~~~~~~~~~~~~~~~~~~~~~~~~~
.. automethod:: pytorch_lightning.callbacks.Callback.on_validation_batch_start
:noindex:
on_validation_batch_end
^^^^^^^^^^^^^^^^^^^^^^^
~~~~~~~~~~~~~~~~~~~~~~~
.. automethod:: pytorch_lightning.callbacks.Callback.on_validation_batch_end
:noindex:
on_test_batch_start
^^^^^^^^^^^^^^^^^^^
~~~~~~~~~~~~~~~~~~~
.. automethod:: pytorch_lightning.callbacks.Callback.on_test_batch_start
:noindex:
on_test_batch_end
^^^^^^^^^^^^^^^^^
~~~~~~~~~~~~~~~~~
.. automethod:: pytorch_lightning.callbacks.Callback.on_test_batch_end
:noindex:
on_batch_end
^^^^^^^^^^^^
on_predict_batch_start
~~~~~~~~~~~~~~~~~~~~~~
.. automethod:: pytorch_lightning.callbacks.Callback.on_batch_end
.. automethod:: pytorch_lightning.callbacks.Callback.on_predict_batch_start
:noindex:
on_predict_batch_end
~~~~~~~~~~~~~~~~~~~~
.. automethod:: pytorch_lightning.callbacks.Callback.on_predict_batch_end
:noindex:
on_train_start
^^^^^^^^^^^^^^
~~~~~~~~~~~~~~
.. automethod:: pytorch_lightning.callbacks.Callback.on_train_start
:noindex:
on_train_end
^^^^^^^^^^^^
~~~~~~~~~~~~
.. automethod:: pytorch_lightning.callbacks.Callback.on_train_end
:noindex:
on_pretrain_routine_start
^^^^^^^^^^^^^^^^^^^^^^^^^
~~~~~~~~~~~~~~~~~~~~~~~~~
.. automethod:: pytorch_lightning.callbacks.Callback.on_pretrain_routine_start
:noindex:
on_pretrain_routine_end
^^^^^^^^^^^^^^^^^^^^^^^
~~~~~~~~~~~~~~~~~~~~~~~
.. automethod:: pytorch_lightning.callbacks.Callback.on_pretrain_routine_end
:noindex:
on_validation_start
^^^^^^^^^^^^^^^^^^^
~~~~~~~~~~~~~~~~~~~
.. automethod:: pytorch_lightning.callbacks.Callback.on_validation_start
:noindex:
on_validation_end
^^^^^^^^^^^^^^^^^
~~~~~~~~~~~~~~~~~
.. automethod:: pytorch_lightning.callbacks.Callback.on_validation_end
:noindex:
on_test_start
^^^^^^^^^^^^^
~~~~~~~~~~~~~
.. automethod:: pytorch_lightning.callbacks.Callback.on_test_start
:noindex:
on_test_end
^^^^^^^^^^^
~~~~~~~~~~~
.. automethod:: pytorch_lightning.callbacks.Callback.on_test_end
:noindex:
on_predict_start
~~~~~~~~~~~~~~~~
.. automethod:: pytorch_lightning.callbacks.Callback.on_predict_start
:noindex:
on_predict_end
~~~~~~~~~~~~~~
.. automethod:: pytorch_lightning.callbacks.Callback.on_predict_end
:noindex:
on_keyboard_interrupt
^^^^^^^^^^^^^^^^^^^^^
~~~~~~~~~~~~~~~~~~~~~
.. automethod:: pytorch_lightning.callbacks.Callback.on_keyboard_interrupt
:noindex:
on_exception
^^^^^^^^^^^^
~~~~~~~~~~~~
.. automethod:: pytorch_lightning.callbacks.Callback.on_exception
:noindex:
on_save_checkpoint
^^^^^^^^^^^^^^^^^^
~~~~~~~~~~~~~~~~~~
.. automethod:: pytorch_lightning.callbacks.Callback.on_save_checkpoint
:noindex:
on_load_checkpoint
^^^^^^^^^^^^^^^^^^
~~~~~~~~~~~~~~~~~~
.. automethod:: pytorch_lightning.callbacks.Callback.on_load_checkpoint
:noindex:
on_before_backward
^^^^^^^^^^^^^^^^^^
~~~~~~~~~~~~~~~~~~
.. automethod:: pytorch_lightning.callbacks.Callback.on_before_backward
:noindex:
on_after_backward
^^^^^^^^^^^^^^^^^
~~~~~~~~~~~~~~~~~
.. automethod:: pytorch_lightning.callbacks.Callback.on_after_backward
:noindex:
on_before_optimizer_step
^^^^^^^^^^^^^^^^^^^^^^^^
~~~~~~~~~~~~~~~~~~~~~~~~
.. automethod:: pytorch_lightning.callbacks.Callback.on_before_optimizer_step
:noindex:
on_before_zero_grad
^^^^^^^^^^^^^^^^^^^
~~~~~~~~~~~~~~~~~~~
.. automethod:: pytorch_lightning.callbacks.Callback.on_before_zero_grad
:noindex:

View File

@@ -50,7 +50,7 @@ Think of this as swapping out the engine in a car!
----------
Understanding the default Trainer loop
Understanding the Default Trainer Loop
--------------------------------------
The Lightning :class:`~pytorch_lightning.trainer.trainer.Trainer` automates the standard optimization loop which every PyTorch user is familiar with:
@@ -75,7 +75,7 @@ The core research logic is simply shifted to the :class:`~pytorch_lightning.core
# loss = loss_function(y_hat, y) moved to training_step
loss = lightning_module.training_step(batch, i)
# Lighting handles automatically:
# Lightning handles automatically:
optimizer.zero_grad()
loss.backward()
optimizer.step()
@@ -109,12 +109,12 @@ Defining a loop within a class interface instead of hard-coding a raw Python for
.. _override default loops:
Overriding the default loops
Overriding the Default Loops
----------------------------
The fastest way to get started with loops is to override functionality of an existing loop.
Lightning has 4 main loops it uses: :class:`~pytorch_lightning.loops.fit_loop.FitLoop` for training and validating,
:class:`~pytorch_lightning.loops.dataloader.evaluation_loop.EvaluationLoop` for testing,
Lightning has 4 main loops: :class:`~pytorch_lightning.loops.fit_loop.FitLoop` for fitting (training and validating),
:class:`~pytorch_lightning.loops.dataloader.evaluation_loop.EvaluationLoop` for validating or testing, and
:class:`~pytorch_lightning.loops.dataloader.prediction_loop.PredictionLoop` for predicting.
For simple changes that don't require a custom loop, you can modify each of these loops.
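For example, a minimal sketch of customizing an existing loop by subclassing it (the epoch-logging body is illustrative only):

.. code-block:: python

    from pytorch_lightning.loops import FitLoop


    class CustomFitLoop(FitLoop):
        def advance(self):
            # run custom logic before each epoch, then defer to the default behavior
            print(f"starting epoch {self.current_epoch}")
            super().advance()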
@@ -166,11 +166,11 @@ Now simply attach the correct loop in the trainer directly:
# fit() now uses the new FitLoop!
trainer.fit(...)
# the equivalent for validate(), test(), predict()
# the equivalent for validate()
val_loop = CustomValLoop()
trainer = Trainer()
trainer.validate_loop = val_loop
trainer.validate(model)
trainer.validate(...)
Now your code is FULLY flexible and you can still leverage ALL the best parts of Lightning!
@@ -179,7 +179,7 @@ Now your code is FULLY flexible and you can still leverage ALL the best parts of
----------
Creating a new loop from scratch
Creating a New Loop From Scratch
--------------------------------
You can also go wild and implement a full loop from scratch by sub-classing the :class:`~pytorch_lightning.loops.base.Loop` base class.
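A skeleton of the interface such a loop must provide; a minimal sketch, assuming the ``done`` property and the ``reset``/``advance`` methods required by the base class:

.. code-block:: python

    from pytorch_lightning.loops.base import Loop


    class MyFancyLoop(Loop):
        @property
        def done(self):
            # condition that terminates `run()`
            ...

        def reset(self):
            # clear internal state so the loop can be run again
            ...

        def advance(self):
            # one iteration of the loop; called repeatedly until `done` is True
            ...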
@@ -212,8 +212,7 @@ Finally, attach it into the :class:`~pytorch_lightning.trainer.trainer.Trainer`:
# fit() now uses your fancy loop!
trainer.fit(...)
Now you have full control over the Trainer.
But beware: The power of loop customization comes with great responsibility.
But beware: loop customization gives you more power and full control over the Trainer, and with great power comes great responsibility.
We recommend that you familiarize yourself with :ref:`overriding the default loops <override default loops>` before you start building a new loop from the ground up.
----------
@@ -222,7 +221,7 @@ Loop API
--------
Here is the full API of methods available in the Loop base class.
The :class:`~pytorch_lightning.loops.base.Loop` class is the base for all loops in Lighting just like the :class:`~pytorch_lightning.core.lightning.LightningModule` is the base for all models.
The :class:`~pytorch_lightning.loops.base.Loop` class is the base of all loops in the same way as the :class:`~pytorch_lightning.core.lightning.LightningModule` is the base of all models.
It defines a public interface that each loop implementation must follow; the key ones are:
Properties
@@ -348,6 +347,12 @@ Each of these :code:`for`-loops represents a class implementing the :class:`~pyt
It is the leaf node in the tree of loops and performs the actual optimization (forward, zero grad, backward, optimizer step).
* - :class:`~pytorch_lightning.loops.optimization.manual_loop.ManualOptimization`
- Substitutes the :class:`~pytorch_lightning.loops.optimization.optimizer_loop.OptimizerLoop` in case of :ref:`manual_optimization` and implements the manual optimization step.
* - :class:`~pytorch_lightning.loops.dataloader.evaluation_loop.EvaluationLoop`
- The :class:`~pytorch_lightning.loops.dataloader.evaluation_loop.EvaluationLoop` is the top-level loop where validation/testing starts.
It simply iterates over each evaluation dataloader from one to the next by calling :code:`EvaluationEpochLoop.run()` in its :code:`advance()` method.
* - :class:`~pytorch_lightning.loops.dataloader.prediction_loop.PredictionLoop`
- The :class:`~pytorch_lightning.loops.dataloader.prediction_loop.PredictionLoop` is the top-level loop where prediction starts.
It simply iterates over each prediction dataloader from one to the next by calling :code:`PredictionEpochLoop.run()` in its :code:`advance()` method.
----------
@@ -382,6 +387,7 @@ To run the following demo, install Flash and `BaaL <https://github.com/ElementAI
# Implement the research use-case where we mask labels from labelled dataset.
datamodule = ActiveLearningDataModule(
ImageClassificationData.from_folders(train_folder="data/hymenoptera_data/train/", batch_size=2),
initial_num_labels=5,
val_split=0.1,
)
@@ -390,7 +396,8 @@ To run the following demo, install Flash and `BaaL <https://github.com/ElementAI
torch.nn.Dropout(p=0.1),
torch.nn.Linear(512, datamodule.num_classes),
)
model = ImageClassifier(backbone="resnet18", head=head, num_classes=datamodule.num_classes, serializer=Probabilities())
model = ImageClassifier(backbone="resnet18", head=head, num_classes=datamodule.num_classes, output=Probabilities())
# 3.1 Create the trainer
trainer = flash.Trainer(max_epochs=3)
@@ -410,7 +417,7 @@ To run the following demo, install Flash and `BaaL <https://github.com/ElementAI
# 5. Save the model!
trainer.save_checkpoint("image_classification_model.pt")
Here is the `Active Learning Loop example <https://github.com/PyTorchLightning/lightning-flash/blob/master/flash_examples/integrations/baal/image_classification_active_learning.py>`_ and the `code for the active learning loop <https://github.com/PyTorchLightning/lightning-flash/blob/master/flash/image/classification/integrations/baal/loop.py#L31>`_.
Here is the `Active Learning Loop example <https://github.com/PyTorchLightning/lightning-flash/blob/master/flash_examples/integrations/baal/image_classification_active_learning.py>`_ and the `code for the active learning loop <https://github.com/PyTorchLightning/lightning-flash/blob/master/flash/image/classification/integrations/baal/loop.py>`_.
----------

View File

@@ -5,7 +5,7 @@ Loops (Advanced)
.. _persisting loop state:
Persisting the state of loops
Persisting the State of Loops
-----------------------------
.. note::
@@ -18,7 +18,7 @@ A powerful property of the class-based loop interface is that it can own an interna
Loop instances can save their state to the checkpoint through corresponding hooks and, if implemented accordingly, resume the state of execution at the appropriate place.
This design is particularly interesting for fault-tolerant training, which is an experimental feature released in Lightning v1.5.
The two hooks :class:`~pytorch_lightning.loops.base.Loop.on_save_checkpoint` and :class:`~pytorch_lightning.loops.base.Loop.on_load_checkpoint` function very similarly to how LightningModules and Callbacks save and load state.
The two hooks :meth:`~pytorch_lightning.loops.base.Loop.on_save_checkpoint` and :meth:`~pytorch_lightning.loops.base.Loop.on_load_checkpoint` function very similarly to how LightningModules and Callbacks save and load state.
.. code-block:: python

View File

@@ -46,7 +46,7 @@ PrecisionPlugin
Furthermore, for multi-node training Lightning provides cluster environment plugins that allow the advanced user
to configure Lighting to integrate with a :ref:`custom-cluster`.
to configure Lightning to integrate with a :ref:`custom-cluster`.
.. image:: ../_static/images/accelerator/overview.svg

View File

@@ -59,7 +59,7 @@ class YieldLoop(OptimizerLoop):
def on_run_start(self, batch, optimizers, batch_idx):
super().on_run_start(batch, optimizers, batch_idx)
if not inspect.isgeneratorfunction(self.trainer.lightning_module.training_step):
raise MisconfigurationException("The LightingModule does not yield anything in the `training_step`.")
raise MisconfigurationException("The `LightningModule` does not yield anything in the `training_step`.")
assert self.trainer.lightning_module.automatic_optimization
# We request the generator once and save it for later

View File

@@ -169,6 +169,10 @@ class Callback(abc.ABC):
"""Called when the training batch begins."""
pass
def on_batch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""Called when the training batch ends."""
pass
def on_validation_batch_start(
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int
) -> None:
@@ -223,10 +227,6 @@ class Callback(abc.ABC):
"""Called when the predict batch ends."""
pass
def on_batch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""Called when the training batch ends."""
pass
def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""Called when the train begins."""
pass

View File

@@ -473,7 +473,7 @@ class NeptuneLogger(LightningLoggerBase):
for key, val in metrics.items():
# `step` is ignored because Neptune expects strictly increasing step values which
# Lighting does not always guarantee.
# Lightning does not always guarantee.
self.experiment[key].log(val)
@rank_zero_only

View File

@@ -2016,7 +2016,7 @@ class Trainer(
"""Attach a custom fit loop to this Trainer.
It will run with
:meth:`~pytorch_lighting.trainer.trainer.Trainer.fit`.
:meth:`~pytorch_lightning.trainer.trainer.Trainer.fit`.
"""
loop.trainer = self
self._fit_loop = loop
@@ -2030,7 +2030,7 @@ class Trainer(
"""Attach a custom validation loop to this Trainer.
It will run with
:meth:`~pytorch_lighting.trainer.trainer.Trainer.validate`. Note that this loop is different from the one
:meth:`~pytorch_lightning.trainer.trainer.Trainer.validate`. Note that this loop is different from the one
running during training inside the :meth:`pytorch_lightning.trainer.trainer.Trainer.fit` call.
"""
loop.trainer = self