loop customization docs (#9609)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: thomas chaton <thomas@grid.ai>
Co-authored-by: edenlightning <66261195+edenlightning@users.noreply.github.com>
Adrian Wälchli 2021-10-18 11:43:11 +02:00 committed by GitHub
parent 01b304ec57
commit 7a9151637c
8 changed files with 518 additions and 3 deletions

View File

@ -1,6 +1,6 @@
Sequential Data
================
===============
Truncated Backpropagation Through Time
--------------------------------------

View File

@ -67,6 +67,71 @@ Loggers API
test_tube
wandb
Loop API
--------
Base Classes
^^^^^^^^^^^^
.. currentmodule:: pytorch_lightning.loops
.. autosummary::
:toctree: api
:nosignatures:
:template: classtemplate.rst
~base.Loop
~dataloader.dataloader_loop.DataLoaderLoop
Default Loop Implementations
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Training
""""""""
.. currentmodule:: pytorch_lightning.loops
.. autosummary::
:toctree: api
:nosignatures:
:template: classtemplate.rst
FitLoop
~epoch.TrainingEpochLoop
~batch.TrainingBatchLoop
~optimization.OptimizerLoop
~optimization.ManualOptimization
Validation and Testing
""""""""""""""""""""""
.. currentmodule:: pytorch_lightning.loops
.. autosummary::
:toctree: api
:nosignatures:
:template: classtemplate.rst
~dataloader.EvaluationLoop
~epoch.EvaluationEpochLoop
Prediction
""""""""""
.. currentmodule:: pytorch_lightning.loops
.. autosummary::
:toctree: api
:nosignatures:
:template: classtemplate.rst
~dataloader.PredictionLoop
~epoch.PredictionEpochLoop
Plugins API
-----------

View File

@ -0,0 +1,403 @@
.. _loop_customization:
Loops
=====
Loops let advanced users swap out the default gradient descent optimization loop at the core of Lightning with a different optimization paradigm.
The Lightning Trainer is built on top of the standard gradient descent optimization loop which works for 90%+ of machine learning use cases:
.. code-block:: python

for i, batch in enumerate(dataloader):
    x, y = batch
    y_hat = model(x)
    loss = loss_function(y_hat, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
However, some newer research use cases, such as meta-learning, active learning, and recommendation systems, require a different loop structure.
For example, here is a simple loop that guides the weight updates with a loss computed on a special validation split:
.. code-block:: python

for i, batch in enumerate(train_dataloader):
    x, y = batch
    y_hat = model(x)
    loss = loss_function(y_hat, y)
    optimizer.zero_grad()
    loss.backward()

    val_loss = 0
    for j, val_batch in enumerate(val_dataloader):
        x, y = val_batch
        y_hat = model(x)
        val_loss += loss_function(y_hat, y)

    scale_gradients(model, 1 / val_loss)
    optimizer.step()
With Lightning Loops, you can customize the Trainer to run non-standard optimization schemes such as the loop above:
.. code-block:: python
trainer = Trainer()
trainer.fit_loop.epoch_loop = MyGradientDescentLoop()
Think of this as swapping out the engine in a car!
Understanding the default Trainer loop
--------------------------------------
The Lightning :class:`~pytorch_lightning.trainer.trainer.Trainer` automates the standard optimization loop which every PyTorch user is familiar with:
.. code-block:: python

for i, batch in enumerate(dataloader):
    x, y = batch
    y_hat = model(x)
    loss = loss_function(y_hat, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
The core research logic is simply shifted to the :class:`~pytorch_lightning.core.lightning.LightningModule`:
.. code-block:: python

for i, batch in enumerate(dataloader):
    # x, y = batch                      moved to training_step
    # y_hat = model(x)                  moved to training_step
    # loss = loss_function(y_hat, y)    moved to training_step
    loss = lightning_module.training_step(batch, i)

    # Lightning handles these automatically:
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
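For reference, here is a minimal sketch of what the corresponding :class:`~pytorch_lightning.core.lightning.LightningModule` could look like; the model, loss, and optimizer below are illustrative placeholders, not prescribed by Lightning:
.. code-block:: python

import torch
from torch.nn import functional as F
from pytorch_lightning import LightningModule


class LitClassifier(LightningModule):
    def __init__(self):
        super().__init__()
        # placeholder model, swap in your own architecture
        self.model = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        # the research logic that moved out of the raw loop lives here
        x, y = batch
        y_hat = self.model(x)
        return F.cross_entropy(y_hat, y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

Lightning then drives this module with its default training loop, described next.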
Under the hood, the above loop is implemented using the :class:`~pytorch_lightning.loops.base.Loop` API like so:
.. code-block:: python

class DefaultLoop(Loop):
    def advance(self, batch, i):
        loss = lightning_module.training_step(batch, i)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    def run(self, dataloader):
        for i, batch in enumerate(dataloader):
            self.advance(batch, i)
Defining a loop within a class interface instead of hard-coding a raw Python for/while loop has several benefits:
1. You can have full control over the data flow through loops.
2. You can add new loops and nest as many of them as you want.
3. If needed, the state of a loop can be :ref:`saved and resumed <persisting loop state>`.
4. New hooks can be injected at any point.
.. image:: https://pl-public-data.s3.amazonaws.com/docs/static/images/loops/epoch-loop-steps.gif
:alt: Animation showing how to convert a standard training loop to a Lightning loop
.. _override default loops:
Overriding the default loops
----------------------------
The fastest way to get started with loops is to override functionality of an existing loop.
Lightning uses three main loop classes to cover the Trainer entry points: :class:`~pytorch_lightning.loops.fit_loop.FitLoop` for training and validating,
:class:`~pytorch_lightning.loops.dataloader.evaluation_loop.EvaluationLoop` for validating and testing,
and :class:`~pytorch_lightning.loops.dataloader.prediction_loop.PredictionLoop` for predicting.
For simple changes that don't require a custom loop, you can modify each of these loops.
Each loop has a series of methods that can be modified.
For example with the :class:`~pytorch_lightning.loops.fit_loop.FitLoop`:
.. code-block:: python

from pytorch_lightning.loops import FitLoop


class MyLoop(FitLoop):
    def advance(self):
        ...

    def on_advance_end(self):
        ...

    def on_run_end(self):
        ...
A full list of all built-in loops and subloops can be found :ref:`here <loop structure>`.
To add your own modifications to a loop, simply subclass an existing loop class and override what you need.
Here is a simple example of how to add a new hook:
.. code-block:: python

from pytorch_lightning.loops import FitLoop


class CustomFitLoop(FitLoop):
    def advance(self):
        # ... whatever code before

        # pass anything you want to the hook
        self.trainer.call_hook("my_new_hook", *args, **kwargs)

        # ... whatever code after
Now simply attach the correct loop in the trainer directly:
.. code-block:: python
trainer = Trainer(...)
trainer.fit_loop = CustomFitLoop()
# fit() now uses the new FitLoop!
trainer.fit(...)
# the equivalent for validate(), test(), predict()
val_loop = CustomValLoop()
trainer = Trainer()
trainer.validate_loop = val_loop
trainer.validate(model)
Now your code is FULLY flexible and you can still leverage ALL the best parts of Lightning!
.. image:: https://pl-public-data.s3.amazonaws.com/docs/static/images/loops/replace-fit-loop.gif
:alt: Animation showing how to replace a loop on the Trainer
Creating a new loop from scratch
--------------------------------
You can also go wild and implement a full loop from scratch by sub-classing the :class:`~pytorch_lightning.loops.base.Loop` base class.
You will need to override a minimum of two things:
.. code-block:: python

from pytorch_lightning.loops import Loop


class MyFancyLoop(Loop):
    @property
    def done(self):
        # provide condition to stop the loop
        ...

    def advance(self):
        # access your dataloader/s in whatever way you want
        # do your fancy optimization things
        # call the lightning module methods at your leisure
        ...
Finally, attach it into the :class:`~pytorch_lightning.trainer.trainer.Trainer`:
.. code-block:: python
trainer = Trainer(...)
trainer.fit_loop = MyFancyLoop()
# fit() now uses your fancy loop!
trainer.fit(...)
Now you have full control over the Trainer.
But beware: The power of loop customization comes with great responsibility.
We recommend that you familiarize yourself with :ref:`overriding the default loops <override default loops>` before you start building a new loop from the ground up.
Loop API
--------
Here is the full API of the methods and properties available in the Loop base class.
The :class:`~pytorch_lightning.loops.base.Loop` class is the base of all loops in Lightning, just like the :class:`~pytorch_lightning.core.lightning.LightningModule` is the base of all models.
It defines a public interface that each loop implementation must follow; the key members are:
Properties
^^^^^^^^^^
done
~~~~
.. autoattribute:: pytorch_lightning.loops.base.Loop.done
:noindex:
skip (optional)
~~~~~~~~~~~~~~~
.. autoattribute:: pytorch_lightning.loops.base.Loop.skip
:noindex:
Methods
^^^^^^^
reset (optional)
~~~~~~~~~~~~~~~~
.. automethod:: pytorch_lightning.loops.base.Loop.reset
:noindex:
advance
~~~~~~~
.. automethod:: pytorch_lightning.loops.base.Loop.advance
:noindex:
run (optional)
~~~~~~~~~~~~~~
.. automethod:: pytorch_lightning.loops.base.Loop.run
:noindex:
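To make these pieces concrete, here is a minimal sketch of a loop that implements the members above to iterate over a plain Python list; the class name and its attributes are illustrative assumptions, not a built-in loop:
.. code-block:: python

from pytorch_lightning.loops import Loop


class SimpleItemLoop(Loop):
    """Processes one item per ``advance`` call and stops when the list is exhausted."""

    def __init__(self, items):
        super().__init__()
        self.items = items
        self.index = 0

    @property
    def done(self) -> bool:
        # the loop stops as soon as this evaluates to True
        return self.index >= len(self.items)

    def reset(self) -> None:
        # called before the loop starts (and again when restarting)
        self.index = 0

    def advance(self) -> None:
        # one unit of work per call
        item = self.items[self.index]
        print(f"processing {item}")
        self.index += 1

The inherited :code:`run()` method calls :code:`reset()` once and then keeps calling :code:`advance()` until :code:`done` returns True. In practice, such a loop would be attached to the Trainer or connected as a subloop rather than run in isolation.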
Subloops
--------
When you want to customize nested loops within loops, use the :meth:`~pytorch_lightning.loops.base.Loop.connect` method:
.. code-block:: python
# Step 1: create your loop
my_epoch_loop = MyEpochLoop()
# Step 2: use connect()
trainer.fit_loop.connect(epoch_loop=my_epoch_loop)
# Trainer runs the fit loop with your new epoch loop!
trainer.fit(model)
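For the snippet above to work, ``MyEpochLoop`` must be a loop that fits into the epoch-loop slot. One way to sketch it, assuming you only want to tweak a single hook, is to subclass the built-in :class:`~pytorch_lightning.loops.epoch.training_epoch_loop.TrainingEpochLoop` (the print statement is just an illustrative stand-in for your custom logic):
.. code-block:: python

from pytorch_lightning.loops import TrainingEpochLoop


class MyEpochLoop(TrainingEpochLoop):
    def on_advance_end(self):
        # custom behaviour after every batch, then defer to the default logic
        print(f"finished batch {self.batch_idx} of epoch {self.trainer.current_epoch}")
        super().on_advance_end()

Depending on the Lightning version, the parent class may expect constructor arguments (such as minimum and maximum steps), so check its signature before instantiating your subclass.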
More about the built-in loops and how they are composed is explained in the next section.
.. image:: https://pl-public-data.s3.amazonaws.com/docs/static/images/loops/connect-epoch-loop.gif
:alt: Animation showing how to connect a custom subloop
.. _loop structure:
Built-in Loops
--------------
The training loop in Lightning is called *fit loop* and is actually a combination of several loops.
Here is what the structure would look like in plain Python:
.. code-block:: python

# FitLoop
for epoch in range(max_epochs):
    # TrainingEpochLoop
    for batch_idx, batch in enumerate(train_dataloader):
        # TrainingBatchLoop
        for split_batch in tbptt_split(batch):
            # OptimizerLoop
            for optimizer_idx, opt in enumerate(optimizers):
                loss = lightning_module.training_step(split_batch, batch_idx, optimizer_idx)
                ...

        # EvaluationEpochLoop
        for val_batch_idx, val_batch in enumerate(val_dataloader):
            lightning_module.validation_step(val_batch, val_batch_idx)
            ...
Each of these :code:`for`-loops represents a class implementing the :class:`~pytorch_lightning.loops.base.Loop` interface.
.. list-table:: Trainer entry points and associated loops
:widths: 25 75
:header-rows: 1
* - Built-in loop
- Description
* - :class:`~pytorch_lightning.loops.fit_loop.FitLoop`
- The :class:`~pytorch_lightning.loops.fit_loop.FitLoop` is the top-level loop where training starts.
It simply counts the epochs and iterates from one to the next by calling :code:`TrainingEpochLoop.run()` in its :code:`advance()` method.
* - :class:`~pytorch_lightning.loops.epoch.training_epoch_loop.TrainingEpochLoop`
- The :class:`~pytorch_lightning.loops.epoch.training_epoch_loop.TrainingEpochLoop` is the one that iterates over the dataloader that the user returns in their :meth:`~pytorch_lightning.core.lightning.LightningModule.train_dataloader` method.
Its main responsibilities are calling the :code:`*_epoch_start` and :code:`*_epoch_end` hooks, accumulating outputs if the user requests them in one of these hooks, and running validation at the requested interval.
The validation is carried out by yet another set of loops, the :class:`~pytorch_lightning.loops.dataloader.evaluation_loop.EvaluationLoop` and its nested :class:`~pytorch_lightning.loops.epoch.evaluation_epoch_loop.EvaluationEpochLoop`.
In its :code:`run()` method, the training epoch loop could in theory simply call :code:`LightningModule.training_step` directly and perform the optimization.
However, Lightning has built-in support for automatic optimization with multiple optimizers and on top of that also supports :doc:`truncated back-propagation through time <../advanced/sequences>`.
For this reason, there are actually two more loops nested under :class:`~pytorch_lightning.loops.epoch.training_epoch_loop.TrainingEpochLoop`.
* - :class:`~pytorch_lightning.loops.batch.training_batch_loop.TrainingBatchLoop`
- The responsibility of the :class:`~pytorch_lightning.loops.batch.training_batch_loop.TrainingBatchLoop` is to split a batch given by the :class:`~pytorch_lightning.loops.epoch.training_epoch_loop.TrainingEpochLoop` along the time-dimension and iterate over the list of splits.
It also keeps track of the hidden state *hiddens* returned by the training step.
By default, when truncated back-propagation through time (TBPTT) is turned off, this loop does not do anything except redirect the call to the :class:`~pytorch_lightning.loops.optimization.optimizer_loop.OptimizerLoop`.
Read more about :doc:`TBPTT <../advanced/sequences>`.
* - :class:`~pytorch_lightning.loops.optimization.optimizer_loop.OptimizerLoop`
- The :class:`~pytorch_lightning.loops.optimization.optimizer_loop.OptimizerLoop` iterates over one or multiple optimizers and for each one it calls the :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step` method with the batch, the current batch index and the optimizer index if multiple optimizers are requested.
It is the leaf node in the tree of loops and performs the actual optimization (forward, zero grad, backward, optimizer step).
* - :class:`~pytorch_lightning.loops.optimization.manual_loop.ManualOptimization`
- Substitutes the :class:`~pytorch_lightning.loops.optimization.optimizer_loop.OptimizerLoop` in case of :ref:`manual_optimization` and implements the manual optimization step.
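Assuming the default composition described in the table above, each loop is reachable as an attribute of its parent, and :code:`connect()` is the supported way to swap a single subloop while keeping the rest intact. The attribute names below mirror the table, and ``MyOptimizerLoop`` is a hypothetical custom loop:
.. code-block:: python

trainer = Trainer()

# walk the default tree of loops, from the outside in
fit_loop = trainer.fit_loop  # FitLoop
epoch_loop = fit_loop.epoch_loop  # TrainingEpochLoop
batch_loop = epoch_loop.batch_loop  # TrainingBatchLoop

# replace only the innermost automatic-optimization loop
batch_loop.connect(optimizer_loop=MyOptimizerLoop())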
Available Loops in Lightning Flash
----------------------------------
`Active Learning <https://en.wikipedia.org/wiki/Active_learning_(machine_learning)>`__ is a machine learning practice in which the user interacts with the learner in order to provide new labels when required.
You can find a real use case in `Lightning Flash <https://github.com/PyTorchLightning/lightning-flash>`_.
Flash implements the :code:`ActiveLearningLoop` that you can use together with the :code:`ActiveLearningDataModule` to label new data on the fly.
To run the following demo, install Flash and `BaaL <https://github.com/ElementAI/baal>`__ first:
.. code-block:: bash
pip install lightning-flash baal
.. code-block:: python
import torch
import flash
from flash.core.classification import Probabilities
from flash.core.data.utils import download_data
from flash.image import ImageClassificationData, ImageClassifier
from flash.image.classification.integrations.baal import ActiveLearningDataModule, ActiveLearningLoop
# 1. Create the DataModule
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "./data")
# Implement the research use-case where we mask labels from labelled dataset.
datamodule = ActiveLearningDataModule(
    ImageClassificationData.from_folders(train_folder="data/hymenoptera_data/train/", batch_size=2),
    val_split=0.1,
)

# 2. Build the task
head = torch.nn.Sequential(
    torch.nn.Dropout(p=0.1),
    torch.nn.Linear(512, datamodule.num_classes),
)
model = ImageClassifier(backbone="resnet18", head=head, num_classes=datamodule.num_classes, serializer=Probabilities())
# 3.1 Create the trainer
trainer = flash.Trainer(max_epochs=3)
# 3.2 Create the active learning loop and connect it to the trainer
active_learning_loop = ActiveLearningLoop(label_epoch_frequency=1)
active_learning_loop.connect(trainer.fit_loop)
trainer.fit_loop = active_learning_loop
# 3.3 Finetune
trainer.finetune(model, datamodule=datamodule, strategy="freeze")
# 4. Predict what's on a few images! ants or bees?
predictions = model.predict("data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg")
print(predictions)
# 5. Save the model!
trainer.save_checkpoint("image_classification_model.pt")
Here is the `runnable example <https://github.com/PyTorchLightning/lightning-flash/blob/master/flash_examples/integrations/baal/image_classification_active_learning.py>`_ and the `code for the active learning loop <https://github.com/PyTorchLightning/lightning-flash/blob/master/flash/image/classification/integrations/baal/loop.py#L31>`_.
Advanced Topics and Examples
----------------------------
Next: :doc:`Advanced loop features and examples <../extensions/loops_advanced>`

View File

@ -0,0 +1,41 @@
:orphan:
Loops (Advanced)
================
.. _persisting loop state:
Persisting the state of loops
-----------------------------
.. note::
This is an experimental feature and is not activated by default.
Set the environment variable ``PL_FAULT_TOLERANT_TRAINING=1`` to enable saving the progress of loops.
Read more about :doc:`fault-tolerant training <../advanced/fault_tolerant_training>`.
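For example, assuming a hypothetical ``train.py`` entry point, the flag can be set for a single run like this:
.. code-block:: bash

PL_FAULT_TOLERANT_TRAINING=1 python train.py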
A powerful property of the class-based loop interface is that each loop can own internal state.
Loop instances can save their state to the checkpoint through corresponding hooks and, if implemented accordingly, resume their state of execution at the appropriate place.
This design is particularly interesting for fault-tolerant training, an experimental feature released in Lightning v1.5.
The two hooks :meth:`~pytorch_lightning.loops.base.Loop.on_save_checkpoint` and :meth:`~pytorch_lightning.loops.base.Loop.on_load_checkpoint` function very similarly to how LightningModules and Callbacks save and load state.
.. code-block:: python

def on_save_checkpoint(self):
    state_dict = {"iteration": self.iteration}
    return state_dict


def on_load_checkpoint(self, state_dict):
    self.iteration = state_dict["iteration"]
When the Trainer is restarting from a checkpoint (e.g., through :code:`Trainer(resume_from_checkpoint=...)`), the loop exposes a boolean attribute :attr:`~pytorch_lightning.loops.base.Loop.restarting`.
Based on the value of this attribute, the user can write the loop in such a way that it can restart from an arbitrary point, given the state loaded from the checkpoint.
For example, the implementation of the :meth:`~pytorch_lightning.loops.base.Loop.reset` method could look like this given our previous example:
.. code-block:: python

def reset(self):
    if not self.restarting:
        self.iteration = 0
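Putting these pieces together, a complete stateful loop whose progress survives a checkpoint could look like the following sketch (the ``iteration`` counter and ``total`` limit are illustrative assumptions):
.. code-block:: python

from pytorch_lightning.loops import Loop


class CountingLoop(Loop):
    def __init__(self, total: int):
        super().__init__()
        self.total = total
        self.iteration = 0

    @property
    def done(self) -> bool:
        return self.iteration >= self.total

    def reset(self) -> None:
        # keep the restored iteration when resuming from a checkpoint
        if not self.restarting:
            self.iteration = 0

    def advance(self) -> None:
        # do one unit of work per call
        self.iteration += 1

    def on_save_checkpoint(self) -> dict:
        return {"iteration": self.iteration}

    def on_load_checkpoint(self, state_dict: dict) -> None:
        self.iteration = state_dict["iteration"]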

View File

@ -85,7 +85,7 @@ PyTorch Lightning
extensions/logging
extensions/metrics
extensions/plugins
extensions/loops
.. toctree::
:maxdepth: 1

View File

@ -272,6 +272,11 @@ Turn off automatic optimization and you control the train loop!
self.manual_backward(loss_b)
opt_b.step()
Loop customization
==================
If you need even more flexibility, you can fully customize the training loop to its core.
Learn more about loops :doc:`here <../extensions/loops>`.
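For instance, once you have written a custom loop class (``MyCustomFitLoop`` is a placeholder here), swapping it in is a one-liner:
.. code-block:: python

trainer = Trainer()
trainer.fit_loop = MyCustomFitLoop()
trainer.fit(model)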
Predict or Deploy
=================

View File

@ -35,7 +35,7 @@ class Loop(ABC, Generic[T]):
This class implements the following loop structure:
.. codeblock:: python
.. code-block:: python
on_run_start()

View File

@ -12,4 +12,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.loops.optimization.manual_loop import ManualOptimization # noqa: F401
from pytorch_lightning.loops.optimization.optimizer_loop import OptimizerLoop # noqa: F401