From 7a9151637cea5a4f1ac64f13cbd1e16b28332a53 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 18 Oct 2021 11:43:11 +0200 Subject: [PATCH] loop customization docs (#9609) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Carlos MocholĂ­ Co-authored-by: thomas chaton Co-authored-by: edenlightning <66261195+edenlightning@users.noreply.github.com> --- docs/source/advanced/sequences.rst | 2 +- docs/source/api_references.rst | 65 +++ docs/source/extensions/loops.rst | 403 ++++++++++++++++++ docs/source/extensions/loops_advanced.rst | 41 ++ docs/source/index.rst | 2 +- docs/source/starter/new-project.rst | 5 + pytorch_lightning/loops/base.py | 2 +- .../loops/optimization/__init__.py | 1 + 8 files changed, 518 insertions(+), 3 deletions(-) create mode 100644 docs/source/extensions/loops.rst create mode 100644 docs/source/extensions/loops_advanced.rst diff --git a/docs/source/advanced/sequences.rst b/docs/source/advanced/sequences.rst index 8e50de4993..2d8d770cbb 100644 --- a/docs/source/advanced/sequences.rst +++ b/docs/source/advanced/sequences.rst @@ -1,6 +1,6 @@ Sequential Data -================ +=============== Truncated Backpropagation Through Time -------------------------------------- diff --git a/docs/source/api_references.rst b/docs/source/api_references.rst index df70b2b0a3..7bc4d8b460 100644 --- a/docs/source/api_references.rst +++ b/docs/source/api_references.rst @@ -67,6 +67,71 @@ Loggers API test_tube wandb +Loop API +-------- + +Base Classes +^^^^^^^^^^^^ + +.. currentmodule:: pytorch_lightning.loops + +.. autosummary:: + :toctree: api + :nosignatures: + :template: classtemplate.rst + + ~base.Loop + ~dataloader.dataloader_loop.DataLoaderLoop + + +Default Loop Implementations +^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Training +"""""""" + +.. currentmodule:: pytorch_lightning.loops + +.. autosummary:: + :toctree: api + :nosignatures: + :template: classtemplate.rst + + FitLoop + ~epoch.TrainingEpochLoop + ~batch.TrainingBatchLoop + ~optimization.OptimizerLoop + ~optimization.ManualOptimization + + +Validation and Testing +"""""""""""""""""""""" + +.. currentmodule:: pytorch_lightning.loops + +.. autosummary:: + :toctree: api + :nosignatures: + :template: classtemplate.rst + + ~dataloader.EvaluationLoop + ~epoch.EvaluationEpochLoop + + +Prediction +"""""""""" + +.. currentmodule:: pytorch_lightning.loops + +.. autosummary:: + :toctree: api + :nosignatures: + :template: classtemplate.rst + + ~dataloader.PredictionLoop + ~epoch.PredictionEpochLoop + + Plugins API ----------- diff --git a/docs/source/extensions/loops.rst b/docs/source/extensions/loops.rst new file mode 100644 index 0000000000..b83a64d2f6 --- /dev/null +++ b/docs/source/extensions/loops.rst @@ -0,0 +1,403 @@ +.. _loop_customization: + +Loops +===== + +Loops let advanced users swap out the default gradient descent optimization loop at the core of Lightning with a different optimization paradigm. + +The Lightning Trainer is built on top of the standard gradient descent optimization loop which works for 90%+ of machine learning use cases: + +.. 
code-block:: python
+
+    for i, batch in enumerate(dataloader):
+        x, y = batch
+        y_hat = model(x)
+        loss = loss_function(y_hat, y)
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+
+However, some new research use cases such as meta-learning, active learning, recommendation systems, etc., require a different loop structure.
+For example, here is a simple loop that guides the weight updates with a loss from a special validation split:
+
+.. code-block:: python
+
+    for i, batch in enumerate(train_dataloader):
+        x, y = batch
+        y_hat = model(x)
+        loss = loss_function(y_hat, y)
+        optimizer.zero_grad()
+        loss.backward()
+
+        val_loss = 0
+        for i, val_batch in enumerate(val_dataloader):
+            x, y = val_batch
+            y_hat = model(x)
+            val_loss += loss_function(y_hat, y)
+
+        scale_gradients(model, 1 / val_loss)
+        optimizer.step()
+
+
+With Lightning Loops, you can replace the default gradient descent optimization with a custom scheme such as the loop above:
+
+.. code-block:: python
+
+    trainer = Trainer()
+    trainer.fit_loop.epoch_loop = MyGradientDescentLoop()
+
+Think of this as swapping out the engine in a car!
+
+Understanding the default Trainer loop
+--------------------------------------
+
+The Lightning :class:`~pytorch_lightning.trainer.trainer.Trainer` automates the standard optimization loop which every PyTorch user is familiar with:
+
+.. code-block:: python
+
+    for i, batch in enumerate(dataloader):
+        x, y = batch
+        y_hat = model(x)
+        loss = loss_function(y_hat, y)
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+
+The core research logic is simply shifted to the :class:`~pytorch_lightning.core.lightning.LightningModule`:
+
+.. code-block:: python
+
+    for i, batch in enumerate(dataloader):
+        # x, y = batch moved to training_step
+        # y_hat = model(x) moved to training_step
+        # loss = loss_function(y_hat, y) moved to training_step
+        loss = lightning_module.training_step(batch, i)
+
+        # Lightning handles these automatically:
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+
+Under the hood, the above loop is implemented using the :class:`~pytorch_lightning.loops.base.Loop` API like so:
+
+.. code-block:: python
+
+    class DefaultLoop(Loop):
+        def advance(self, batch, i):
+            loss = lightning_module.training_step(batch, i)
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+
+        def run(self, dataloader):
+            for i, batch in enumerate(dataloader):
+                self.advance(batch, i)
+
+Defining a loop within a class interface instead of hard-coding a raw Python for/while loop has several benefits:
+
+1. You can have full control over the data flow through loops.
+2. You can add new loops and nest as many of them as you want.
+3. If needed, the state of a loop can be :ref:`saved and resumed <persisting loop state>`.
+4. New hooks can be injected at any point.
+
+.. image:: https://pl-public-data.s3.amazonaws.com/docs/static/images/loops/epoch-loop-steps.gif
+    :alt: Animation showing how to convert a standard training loop to a Lightning loop
+
+
+.. _override default loops:
+
+Overriding the default loops
+----------------------------
+
+The fastest way to get started with loops is to override functionality of an existing loop.
+Lightning uses three main loops: :class:`~pytorch_lightning.loops.fit_loop.FitLoop` for training and validating,
+:class:`~pytorch_lightning.loops.dataloader.evaluation_loop.EvaluationLoop` for testing,
+and :class:`~pytorch_lightning.loops.dataloader.prediction_loop.PredictionLoop` for predicting.
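+
+Each of these loops is attached to the :class:`~pytorch_lightning.trainer.trainer.Trainer` as an attribute.
+A minimal sketch of inspecting the defaults (``trainer.fit_loop`` and ``trainer.validate_loop`` appear later on this page; the ``test_loop`` and ``predict_loop`` attribute names are assumed to follow the same pattern):
+
+.. code-block:: python
+
+    trainer = Trainer()
+
+    # each entry point owns its own loop instance
+    print(type(trainer.fit_loop))  # FitLoop
+    print(type(trainer.validate_loop))  # EvaluationLoop
+    print(type(trainer.test_loop))  # EvaluationLoop (assumed attribute name)
+    print(type(trainer.predict_loop))  # PredictionLoop (assumed attribute name)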
+
+For simple changes that don't require a custom loop, you can modify each of these loops.
+
+Each loop has a series of methods that can be modified.
+For example, with the :class:`~pytorch_lightning.loops.fit_loop.FitLoop`:
+
+.. code-block::
+
+    from pytorch_lightning.loops import FitLoop
+
+    class MyLoop(FitLoop):
+
+        def advance(self):
+            ...
+
+        def on_advance_end(self):
+            ...
+
+        def on_run_end(self):
+            ...
+
+A full list with all built-in loops and subloops can be found :ref:`here <loop structure>`.
+
+To add your own modifications to a loop, simply subclass an existing loop class and override what you need.
+Here is a simple example of how to add a new hook:
+
+.. code-block:: python
+
+    from pytorch_lightning.loops import FitLoop
+
+
+    class CustomFitLoop(FitLoop):
+        def advance(self):
+            # ... whatever code before
+
+            # pass anything you want to the hook
+            self.trainer.call_hook("my_new_hook", *args, **kwargs)
+
+            # ... whatever code after
+
+Now simply attach the new loop to the Trainer directly:
+
+.. code-block:: python
+
+    trainer = Trainer(...)
+    trainer.fit_loop = CustomFitLoop()
+
+    # fit() now uses the new FitLoop!
+    trainer.fit(...)
+
+    # the equivalent for validate(), test(), predict()
+    val_loop = CustomValLoop()
+    trainer = Trainer()
+    trainer.validate_loop = val_loop
+    trainer.validate(model)
+
+Now your code is FULLY flexible and you can still leverage ALL the best parts of Lightning!
+
+.. image:: https://pl-public-data.s3.amazonaws.com/docs/static/images/loops/replace-fit-loop.gif
+    :alt: Animation showing how to replace a loop on the Trainer
+
+Creating a new loop from scratch
+--------------------------------
+
+You can also go wild and implement a full loop from scratch by subclassing the :class:`~pytorch_lightning.loops.base.Loop` base class.
+You will need to override a minimum of two things:
+
+.. code-block::
+
+    from pytorch_lightning.loops import Loop
+
+    class MyFancyLoop(Loop):
+
+        @property
+        def done(self):
+            # provide condition to stop the loop
+            ...
+
+        def advance(self):
+            # access your dataloader/s in whatever way you want
+            # do your fancy optimization things
+            # call the lightning module methods at your leisure
+            ...
+
+Finally, attach it to the :class:`~pytorch_lightning.trainer.trainer.Trainer`:
+
+.. code-block:: python
+
+    trainer = Trainer(...)
+    trainer.fit_loop = MyFancyLoop()
+
+    # fit() now uses your fancy loop!
+    trainer.fit(...)
+
+Now you have full control over the Trainer.
+But beware: the power of loop customization comes with great responsibility.
+We recommend that you familiarize yourself with :ref:`overriding the default loops <override default loops>` before you start building a new loop from the ground up.
+
+Loop API
+--------
+Here is the full API of methods and properties available in the Loop base class.
+
+The :class:`~pytorch_lightning.loops.base.Loop` class is the base for all loops in Lightning, just like the :class:`~pytorch_lightning.core.lightning.LightningModule` is the base for all models.
+It defines a public interface that each loop implementation must follow; the key ones are:
+
+Properties
+^^^^^^^^^^
+
+done
+~~~~
+
+.. autoattribute:: pytorch_lightning.loops.base.Loop.done
+    :noindex:
+
+skip (optional)
+~~~~~~~~~~~~~~~
+
+.. autoattribute:: pytorch_lightning.loops.base.Loop.skip
+    :noindex:
+
+Methods
+^^^^^^^
+
+reset (optional)
+~~~~~~~~~~~~~~~~
+
+.. automethod:: pytorch_lightning.loops.base.Loop.reset
+    :noindex:
+
+advance
+~~~~~~~
+
+.. automethod:: pytorch_lightning.loops.base.Loop.advance
+    :noindex:
+
+run (optional)
+~~~~~~~~~~~~~~
+
+.. 
automethod:: pytorch_lightning.loops.base.Loop.run
+    :noindex:
+
+
+Subloops
+--------
+
+When you want to customize nested loops within loops, use the :meth:`~pytorch_lightning.loops.base.Loop.connect` method:
+
+.. code-block:: python
+
+    # Step 1: create your loop
+    my_epoch_loop = MyEpochLoop()
+
+    # Step 2: use connect()
+    trainer.fit_loop.connect(epoch_loop=my_epoch_loop)
+
+    # Trainer runs the fit loop with your new epoch loop!
+    trainer.fit(model)
+
+More about the built-in loops and how they are composed is explained in the next section.
+
+.. image:: https://pl-public-data.s3.amazonaws.com/docs/static/images/loops/connect-epoch-loop.gif
+    :alt: Animation showing how to connect a custom subloop
+
+.. _loop structure:
+
+Built-in Loops
+--------------
+
+The training loop in Lightning is called *fit loop* and is actually a combination of several loops.
+Here is what the structure would look like in plain Python:
+
+.. code-block:: python
+
+    # FitLoop
+    for epoch in range(max_epochs):
+
+        # TrainingEpochLoop
+        for batch_idx, batch in enumerate(train_dataloader):
+
+            # TrainingBatchLoop
+            for split_batch in tbptt_split(batch):
+
+                # OptimizerLoop
+                for optimizer_idx, opt in enumerate(optimizers):
+
+                    loss = lightning_module.training_step(batch, batch_idx, optimizer_idx)
+                    ...
+
+            # EvaluationEpochLoop
+            for batch_idx, batch in enumerate(val_dataloader):
+                lightning_module.validation_step(batch, batch_idx)
+                ...
+
+
+Each of these :code:`for`-loops represents a class implementing the :class:`~pytorch_lightning.loops.base.Loop` interface.
+
+
+.. list-table:: Trainer entry points and associated loops
+   :widths: 25 75
+   :header-rows: 1
+
+   * - Built-in loop
+     - Description
+   * - :class:`~pytorch_lightning.loops.fit_loop.FitLoop`
+     - The :class:`~pytorch_lightning.loops.fit_loop.FitLoop` is the top-level loop where training starts.
+       It simply counts the epochs and iterates from one to the next by calling :code:`TrainingEpochLoop.run()` in its :code:`advance()` method.
+   * - :class:`~pytorch_lightning.loops.epoch.training_epoch_loop.TrainingEpochLoop`
+     - The :class:`~pytorch_lightning.loops.epoch.training_epoch_loop.TrainingEpochLoop` is the one that iterates over the dataloader that the user returns in their :meth:`~pytorch_lightning.core.lightning.LightningModule.train_dataloader` method.
+       Its main responsibilities are calling the :code:`*_epoch_start` and :code:`*_epoch_end` hooks, accumulating outputs if the user requests them in one of these hooks, and running validation at the requested interval.
+       The validation is carried out by yet another loop, the :class:`~pytorch_lightning.loops.epoch.evaluation_epoch_loop.EvaluationEpochLoop`.
+
+       In its :code:`run()` method, the training epoch loop could in theory simply call :code:`LightningModule.training_step` directly and perform the optimization.
+       However, Lightning has built-in support for automatic optimization with multiple optimizers and on top of that also supports :doc:`truncated back-propagation through time <../advanced/sequences>`.
+       For this reason, there are two more loops nested under :class:`~pytorch_lightning.loops.epoch.training_epoch_loop.TrainingEpochLoop`.
+ * - :class:`~pytorch_lightning.loops.batch.training_batch_loop.TrainingBatchLoop` + - The responsibility of the :class:`~pytorch_lightning.loops.batch.training_batch_loop.TrainingBatchLoop` is to split a batch given by the :class:`~pytorch_lightning.loops.epoch.training_epoch_loop.TrainingEpochLoop` along the time-dimension and iterate over the list of splits. + It also keeps track of the hidden state *hiddens* returned by the training step. + By default, when truncated back-propagation through time (TBPTT) is turned off, this loop does not do anything except redirect the call to the :class:`~pytorch_lightning.loops.optimization.optimizer_loop.OptimizerLoop`. + Read more about :doc:`TBPTT <../advanced/sequences>`. + * - :class:`~pytorch_lightning.loops.optimization.optimizer_loop.OptimizerLoop` + - The :class:`~pytorch_lightning.loops.optimization.optimizer_loop.OptimizerLoop` iterates over one or multiple optimizers and for each one it calls the :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step` method with the batch, the current batch index and the optimizer index if multiple optimizers are requested. + It is the leaf node in the tree of loops and performs the actual optimization (forward, zero grad, backward, optimizer step). + * - :class:`~pytorch_lightning.loops.optimization.manual_loop.ManualOptimization` + - Substitutes the :class:`~pytorch_lightning.loops.optimization.optimizer_loop.OptimizerLoop` in case of :ref:`manual_optimization` and implements the manual optimization step. + + +Available Loops in Lightning Flash +---------------------------------- + +`Active Learning `__ is a machine learning practice in which the user interacts with the learner in order to provide new labels when required. + +You can find a real use case in `Lightning Flash `_. + +Flash implements the :code:`ActiveLearningLoop` that you can use together with the :code:`ActiveLearningDataModule` to label new data on the fly. +To run the following demo, install Flash and `BaaL `__ first: + +.. code-block:: bash + + pip install lightning-flash baal + +.. code-block:: python + + import torch + + import flash + from flash.core.classification import Probabilities + from flash.core.data.utils import download_data + from flash.image import ImageClassificationData, ImageClassifier + from flash.image.classification.integrations.baal import ActiveLearningDataModule, ActiveLearningLoop + + # 1. Create the DataModule + download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "./data") + + # Implement the research use-case where we mask labels from labelled dataset. + datamodule = ActiveLearningDataModule( + ImageClassificationData.from_folders(train_folder="data/hymenoptera_data/train/", batch_size=2), + val_split=0.1, + ) + + # 2. Build the task + head = torch.nn.Sequential( + torch.nn.Dropout(p=0.1), + torch.nn.Linear(512, datamodule.num_classes), + ) + model = ImageClassifier(backbone="resnet18", head=head, num_classes=datamodule.num_classes, serializer=Probabilities()) + + # 3.1 Create the trainer + trainer = flash.Trainer(max_epochs=3) + + # 3.2 Create the active learning loop and connect it to the trainer + active_learning_loop = ActiveLearningLoop(label_epoch_frequency=1) + active_learning_loop.connect(trainer.fit_loop) + trainer.fit_loop = active_learning_loop + + # 3.3 Finetune + trainer.finetune(model, datamodule=datamodule, strategy="freeze") + + # 4. Predict what's on a few images! ants or bees? 
+    predictions = model.predict("data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg")
+    print(predictions)
+
+    # 5. Save the model!
+    trainer.save_checkpoint("image_classification_model.pt")
+
+Here is the `runnable example `_ and the `code for the active learning loop `_.
+
+Advanced Topics and Examples
+----------------------------
+
+Next: :doc:`Advanced loop features and examples <../extensions/loops_advanced>`
diff --git a/docs/source/extensions/loops_advanced.rst b/docs/source/extensions/loops_advanced.rst
new file mode 100644
index 0000000000..6cf8ceb72b
--- /dev/null
+++ b/docs/source/extensions/loops_advanced.rst
@@ -0,0 +1,41 @@
+:orphan:
+
+Loops (Advanced)
+================
+
+.. _persisting loop state:
+
+Persisting the state of loops
+-----------------------------
+
+.. note::
+
+    This is an experimental feature and is not activated by default.
+    Set the environment variable ``PL_FAULT_TOLERANT_TRAINING=1`` to enable saving the progress of loops.
+    Read more about :doc:`fault-tolerant training <../advanced/fault_tolerant_training>`.
+
+A powerful property of the class-based loop interface is that it can own an internal state.
+Loop instances can save their state to the checkpoint through corresponding hooks and, if implemented accordingly, resume the state of execution at the appropriate place.
+This design is particularly interesting for fault-tolerant training, which is an experimental feature released in Lightning v1.5.
+
+The two hooks :meth:`~pytorch_lightning.loops.base.Loop.on_save_checkpoint` and :meth:`~pytorch_lightning.loops.base.Loop.on_load_checkpoint` function very similarly to how LightningModules and Callbacks save and load state.
+
+.. code-block:: python
+
+    def on_save_checkpoint(self):
+        state_dict = {"iteration": self.iteration}
+        return state_dict
+
+
+    def on_load_checkpoint(self, state_dict):
+        self.iteration = state_dict["iteration"]
+
+When the Trainer is restarting from a checkpoint (e.g., through :code:`Trainer(resume_from_checkpoint=...)`), the loop exposes a boolean attribute :attr:`~pytorch_lightning.loops.base.Loop.restarting`.
+Based on the value of this attribute, the user can write the loop in such a way that it can restart from an arbitrary point, given the state loaded from the checkpoint.
+For example, the implementation of the :meth:`~pytorch_lightning.loops.base.Loop.reset` method for our previous example could look like this:
+
+.. code-block:: python
+
+    def reset(self):
+        if not self.restarting:
+            self.iteration = 0
diff --git a/docs/source/index.rst b/docs/source/index.rst
index f0eb5c05af..ea3e606d72 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -85,7 +85,7 @@ PyTorch Lightning
    extensions/logging
    extensions/metrics
    extensions/plugins
-
+   extensions/loops
 .. toctree::
    :maxdepth: 1
diff --git a/docs/source/starter/new-project.rst b/docs/source/starter/new-project.rst
index 88213637d4..0626aa09db 100644
--- a/docs/source/starter/new-project.rst
+++ b/docs/source/starter/new-project.rst
@@ -272,6 +272,11 @@ Turn off automatic optimization and you control the train loop!
             self.manual_backward(loss_b)
             opt_b.step()
 
+Loop customization
+==================
+
+If you need even more flexibility, you can fully customize the training loop down to its core.
+Learn more about loops :doc:`here <../extensions/loops>`.
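+
+As a quick taste, a custom loop is attached by assigning it to the Trainer before calling ``fit()``.
+A minimal sketch (``MyCustomFitLoop`` is a hypothetical loop subclass, not part of Lightning):
+
+.. code-block:: python
+
+    trainer = Trainer()
+    trainer.fit_loop = MyCustomFitLoop()  # hypothetical custom loop
+    trainer.fit(model)
+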
Predict or Deploy ================= diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index 1a19c753b0..ef53df92c6 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -35,7 +35,7 @@ class Loop(ABC, Generic[T]): This class implements the following loop structure: - .. codeblock:: python + .. code-block:: python on_run_start() diff --git a/pytorch_lightning/loops/optimization/__init__.py b/pytorch_lightning/loops/optimization/__init__.py index 17e96c49d3..07249b6a13 100644 --- a/pytorch_lightning/loops/optimization/__init__.py +++ b/pytorch_lightning/loops/optimization/__init__.py @@ -12,4 +12,5 @@ # See the License for the specific language governing permissions and # limitations under the License. +from pytorch_lightning.loops.optimization.manual_loop import ManualOptimization # noqa: F401 from pytorch_lightning.loops.optimization.optimizer_loop import OptimizerLoop # noqa: F401