Update child modules docs (#11198)

This commit is contained in:
Rohit Gupta 2022-01-11 14:17:25 +05:30 committed by GitHub
parent 34c62da37d
commit 9092cf3a35
5 changed files with 214 additions and 248 deletions

View File

@ -2,11 +2,13 @@
from pytorch_lightning.core.lightning import LightningModule
#################
Transfer Learning
-----------------
#################
***********************
Using Pretrained Models
^^^^^^^^^^^^^^^^^^^^^^^
***********************
Sometimes we want to use a LightningModule as a pretrained model. This is fine because
a LightningModule is just a `torch.nn.Module`!
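For instance, a minimal sketch of this idea (illustrative names only; the page's own ``AutoEncoder`` example follows), reusing a pretrained LightningModule's submodule inside a plain ``nn.Module``:

.. code-block:: python

    import torch
    from torch import nn


    class ImageClassifier(nn.Module):
        def __init__(self, pretrained_autoencoder, num_classes=10, latent_dim=64):
            super().__init__()
            # a LightningModule is an nn.Module, so its submodules can be reused directly
            self.feature_extractor = pretrained_autoencoder.encoder
            self.classifier = nn.Linear(latent_dim, num_classes)

        def forward(self, x):
            # keep the pretrained encoder frozen during feature extraction
            with torch.no_grad():
                features = self.feature_extractor(x)
            return self.classifier(features)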
@ -44,8 +46,8 @@ Let's use the `AutoEncoder` as a feature extractor in a separate model.
We used our pretrained Autoencoder (a LightningModule) for transfer learning!
Example: Imagenet (computer Vision)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Example: Imagenet (Computer Vision)
===================================
.. testcode::
:skipif: not _TORCHVISION_AVAILABLE
@ -96,7 +98,8 @@ We used a pretrained model on imagenet, finetuned on CIFAR-10 to predict on CIFA
In the non-academic world, you would finetune on a small dataset you have and predict on your own data.
Example: BERT (NLP)
^^^^^^^^^^^^^^^^^^^
===================
Lightning is completely agnostic to what's used for transfer learning so long
as it is a `torch.nn.Module` subclass.
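For instance, a rough sketch of wrapping a pretrained Hugging Face BERT (this assumes the ``transformers`` package and is not the exact example used on this page):

.. code-block:: python

    import torch
    from torch import nn
    import pytorch_lightning as pl
    from transformers import BertModel


    class LitBertClassifier(pl.LightningModule):
        def __init__(self, num_classes=2):
            super().__init__()
            # any ``torch.nn.Module`` works here, pretrained or not
            self.bert = BertModel.from_pretrained("bert-base-uncased")
            self.head = nn.Linear(self.bert.config.hidden_size, num_classes)

        def training_step(self, batch, batch_idx):
            input_ids, attention_mask, labels = batch
            hidden = self.bert(input_ids=input_ids, attention_mask=attention_mask).pooler_output
            return nn.functional.cross_entropy(self.head(hidden), labels)

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=2e-5)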

View File

@ -1,59 +1,41 @@
.. testsetup:: *
import torch
from pytorch_lightning.trainer.trainer import Trainer
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.core.lightning import LightningModule
class LitMNIST(LightningModule):
def __init__(self):
super().__init__()
def train_dataloader():
pass
def val_dataloader():
pass
def test_dataloader():
pass
Child Modules
-------------
Research projects tend to test different approaches to the same dataset.
This is very easy to do in Lightning with inheritance.
For example, imagine we now want to train an Autoencoder to use as a feature extractor for MNIST images.
We are extending our Autoencoder from the `LitMNIST`-module which already defines all the dataloading.
The only things that change in the `Autoencoder` model are the init, forward, training, validation and test step.
For example, imagine we now want to train an ``AutoEncoder`` to use as a feature extractor for images.
The only things that change in the ``LitAutoEncoder`` model are the init, forward, training, validation and test step.
.. testcode::
.. code-block:: python
class Encoder(torch.nn.Module):
pass
...
class Decoder(torch.nn.Module):
pass
...
class AutoEncoder(LitMNIST):
class AutoEncoder(torch.nn.Module):
def __init__(self):
super().__init__()
self.encoder = Encoder()
self.decoder = Decoder()
self.metric = MSE()
def forward(self, x):
return self.encoder(x)
return self.decoder(self.encoder(x))
class LitAutoEncoder(LightningModule):
def __init__(self, auto_encoder):
super().__init__()
self.auto_encoder = auto_encoder
self.metric = torch.nn.MSELoss()
def forward(self, x):
return self.auto_encoder.encoder(x)
def training_step(self, batch, batch_idx):
x, _ = batch
representation = self.encoder(x)
x_hat = self.decoder(representation)
x_hat = self.auto_encoder(x)
loss = self.metric(x, x_hat)
return loss
@ -65,25 +47,24 @@ The only things that change in the `Autoencoder` model are the init, forward, tr
def _shared_eval(self, batch, batch_idx, prefix):
x, _ = batch
representation = self.encoder(x)
x_hat = self.decoder(representation)
x_hat = self.auto_encoder(x)
loss = self.metric(x, x_hat)
self.log(f"{prefix}_loss", loss)
and we can train this using the same trainer
and we can train this using the ``Trainer``:
.. code-block:: python
autoencoder = AutoEncoder()
auto_encoder = AutoEncoder()
lightning_module = LitAutoEncoder(auto_encoder)
trainer = Trainer()
trainer.fit(autoencoder)
trainer.fit(lightning_module, train_dataloader, val_dataloader)
And remember that the forward method should define the practical use of a LightningModule.
In this case, we want to use the `AutoEncoder` to extract image representations
And remember that the forward method should define the practical use of a :class:`~pytorch_lightning.core.lightning.LightningModule`.
In this case, we want to use the ``LitAutoEncoder`` to extract image representations:
.. code-block:: python
some_images = torch.Tensor(32, 1, 28, 28)
representations = autoencoder(some_images)
representations = lightning_module(some_images)

View File

@ -3,15 +3,17 @@
.. _lightning_module:
###############
LightningModule
===============
###############
A :class:`~LightningModule` organizes your PyTorch code into 6 sections:
- Computations (init).
- Train loop (training_step)
- Validation loop (validation_step)
- Test loop (test_step)
- Prediction loop (predict_step)
- Train Loop (training_step)
- Validation Loop (validation_step)
- Test Loop (test_step)
- Prediction Loop (predict_step)
- Optimizers and LR Schedulers (configure_optimizers)
|
@ -85,8 +87,9 @@ Thus, to use Lightning, you just need to organize your code which takes about 30
------------
Minimal Example
---------------
***************
Starter Example
***************
Here are the only required methods.
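As a sketch, a LightningModule with just these methods could look like the following (an illustrative classifier; the exact example on this page may differ):

.. code-block:: python

    import torch
    from torch import nn
    import pytorch_lightning as pl


    class LitModel(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.l1 = nn.Linear(28 * 28, 10)

        def training_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self.l1(x.view(x.size(0), -1))
            return nn.functional.cross_entropy(y_hat, y)

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=1e-3)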
@ -147,11 +150,13 @@ The LightningModule has many convenience methods, but the core ones you need to
----------
********
Training
--------
********
Training Loop
^^^^^^^^^^^^^
=============
To activate the training loop, override the :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step` method.
.. code-block:: python
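    # a sketch of a typical training_step (the original example here may differ)
    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = torch.nn.functional.cross_entropy(y_hat, y)
        # returning the loss is all Lightning needs to run backward and the optimizer step
        return loss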
@ -171,15 +176,14 @@ Under the hood, Lightning does the following (pseudocode):
.. code-block:: python
# put model in train mode
# put model in train mode and enable gradient calculation
model.train()
torch.set_grad_enabled(True)
losses = []
for batch in train_dataloader:
# forward
loss = training_step(batch)
losses.append(loss.detach())
outs = []
for batch_idx, batch in enumerate(train_dataloader):
loss = training_step(batch, batch_idx)
outs.append(loss.detach())
# clear gradients
optimizer.zero_grad()
@ -191,8 +195,9 @@ Under the hood, Lightning does the following (pseudocode):
optimizer.step()
Training Epoch-Level Metrics
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Train Epoch-level Metrics
=========================
If you want to calculate epoch-level metrics and log them, use :meth:`~pytorch_lightning.core.lightning.LightningModule.log`.
.. code-block:: python
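    # a sketch (the original example here may differ): log with on_epoch=True
    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.cross_entropy(self(x), y)
        # on_step=True logs the per-step value; on_epoch=True also accumulates
        # and logs the epoch-level aggregate
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss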
@ -213,10 +218,10 @@ requested metrics across a complete epoch and devices. Here's the pseudocode of
.. code-block:: python
outs = []
for batch in train_dataloader:
for batch_idx, batch in enumerate(train_dataloader):
# forward
out = training_step(val_batch)
outs.append(out)
loss = training_step(batch, batch_idx)
outs.append(loss)
# clear gradients
optimizer.zero_grad()
@ -227,11 +232,12 @@ requested metrics across a complete epoch and devices. Here's the pseudocode of
# update parameters
optimizer.step()
epoch_metric = torch.mean(torch.stack([x["train_loss"] for x in outs]))
epoch_metric = torch.mean(torch.stack([x for x in outs]))
Train Epoch-Level Operations
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
If you need to do something with all the outputs of each :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step`.
Train Epoch-level Operations
============================
If you need to do something with all the outputs of each :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step`,
override the :meth:`~pytorch_lightning.core.lightning.LightningModule.training_epoch_end` method.
.. code-block:: python
@ -253,10 +259,10 @@ The matching pseudocode is:
.. code-block:: python
outs = []
for batch in train_dataloader:
for batch_idx, batch in enumerate(train_dataloader):
# forward
out = training_step(val_batch)
outs.append(out)
loss = training_step(batch, batch_idx)
outs.append(loss)
# clear gradients
optimizer.zero_grad()
@ -270,7 +276,8 @@ The matching pseudocode is:
training_epoch_end(outs)
Training with DataParallel
~~~~~~~~~~~~~~~~~~~~~~~~~~
==========================
When training using a ``strategy`` that splits data from each batch across GPUs, you might sometimes
need to aggregate the outputs on the main GPU for processing (DP, or DDP2).
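A hedged sketch of such aggregation with ``training_step_end`` (names are illustrative):

.. code-block:: python

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        # under DP/DDP2 this runs once per sub-batch on each device
        return {"loss": torch.nn.functional.cross_entropy(y_hat, y)}


    def training_step_end(self, step_output):
        # the per-device results are gathered on the main GPU
        return step_output["loss"].mean()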
@ -309,12 +316,12 @@ Here is the Lightning training pseudo-code for DP:
.. code-block:: python
outs = []
for train_batch in train_dataloader:
for batch_idx, train_batch in enumerate(train_dataloader):
batches = split_batch(train_batch)
dp_outs = []
for sub_batch in batches:
# 1
dp_out = training_step(sub_batch)
dp_out = training_step(sub_batch, batch_idx)
dp_outs.append(dp_out)
# 2
@ -327,8 +334,13 @@ Here is the Lightning training pseudo-code for DP:
------------------
**********
Validation
**********
Validation Loop
^^^^^^^^^^^^^^^
===============
To activate the validation loop while training, override the :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_step` method.
.. code-block:: python
@ -345,8 +357,8 @@ Under the hood, Lightning does the following (pseudocode):
.. code-block:: python
# ...
for batch in train_dataloader:
loss = model.training_step()
for batch_idx, batch in enumerate(train_dataloader):
loss = model.training_step(batch, batch_idx)
loss.backward()
# ...
@ -356,8 +368,8 @@ Under the hood, Lightning does the following (pseudocode):
model.eval()
# ----------------- VAL LOOP ---------------
for val_batch in model.val_dataloader:
val_out = model.validation_step(val_batch)
for val_batch_idx, val_batch in enumerate(val_dataloader):
val_out = model.validation_step(val_batch, val_batch_idx)
# ----------------- VAL LOOP ---------------
# enable grads + batchnorm + dropout
@ -373,8 +385,18 @@ and calling :meth:`~pytorch_lightning.trainer.trainer.Trainer.validate`.
trainer = Trainer()
trainer.validate(model)
Validation Epoch-Level Metrics
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. note::
It is recommended to validate on a single device to ensure each sample/batch gets evaluated exactly once.
This helps make sure benchmarking for research papers is done the right way. Otherwise, in a
multi-device setting, samples could be duplicated when :class:`~torch.utils.data.distributed.DistributedSampler`
is used, e.g. with ``strategy="ddp"``, because it replicates some samples on some devices to make sure all devices have
the same batch size in case of uneven inputs.
Validation Epoch-level Metrics
==============================
If you need to do something with all the outputs of each :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_step`,
override the :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_epoch_end` method.
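A rough sketch (illustrative; not necessarily the exact example that follows):

.. code-block:: python

    def validation_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.cross_entropy(self(x), y)
        return loss


    def validation_epoch_end(self, validation_step_outputs):
        # one entry per validation_step call
        avg_loss = torch.stack(validation_step_outputs).mean()
        self.log("val_loss_epoch", avg_loss)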
@ -393,7 +415,8 @@ override the :meth:`~pytorch_lightning.core.lightning.LightningModule.validation
...
Validating with DataParallel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
============================
When validating using a ``strategy`` that splits data from each batch across GPUs, you might sometimes
need to aggregate the outputs on the main GPU for processing (DP, or DDP2).
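For example, a sketch using ``validation_step_end`` (illustrative):

.. code-block:: python

    def validation_step(self, batch, batch_idx):
        x, y = batch
        return {"val_loss": torch.nn.functional.cross_entropy(self(x), y)}


    def validation_step_end(self, step_output):
        # aggregate the per-device pieces on the main GPU
        self.log("val_loss", step_output["val_loss"].mean())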
@ -450,8 +473,13 @@ Here is the Lightning validation pseudo-code for DP:
----------------
*******
Testing
*******
Test Loop
^^^^^^^^^
=========
The process for enabling a test loop is the same as the process for enabling a validation loop; please refer to
the section above for details. For this, you need to override the :meth:`~pytorch_lightning.core.lightning.LightningModule.test_step` method.
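A minimal sketch of a ``test_step`` (illustrative):

.. code-block:: python

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        self.log("test_loss", torch.nn.functional.cross_entropy(y_hat, y))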
@ -482,118 +510,79 @@ There are two ways to call ``test()``:
trainer = Trainer()
trainer.test(model, dataloaders=test_dataloader)
.. note::
It is recommended to validate on a single device to ensure each sample/batch gets evaluated exactly once.
This helps make sure benchmarking for research papers is done the right way. Otherwise, in a
multi-device setting, samples could be duplicated when :class:`~torch.utils.data.distributed.DistributedSampler`
is used, e.g. with ``strategy="ddp"``, because it replicates some samples on some devices to make sure all devices have
the same batch size in case of uneven inputs.
----------
Inference (Prediction Loop)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
For research, LightningModules are best structured as systems.
*********
Inference
*********
Prediction Loop
===============
By default, the :meth:`~pytorch_lightning.core.lightning.LightningModule.predict_step` method runs the
:meth:`~pytorch_lightning.core.lightning.LightningModule.forward` method. To customize this behaviour,
simply override the :meth:`~pytorch_lightning.core.lightning.LightningModule.predict_step` method.
For this example, let's override ``predict_step`` and try out `Monte Carlo Dropout <https://arxiv.org/pdf/1506.02142.pdf>`_:
.. code-block:: python
import pytorch_lightning as pl
import torch
from torch import nn
class Autoencoder(pl.LightningModule):
def __init__(self, latent_dim=2):
class LitMCdropoutModel(pl.LightningModule):
def __init__(self, model, mc_iteration):
super().__init__()
self.encoder = nn.Sequential(nn.Linear(28 * 28, 256), nn.ReLU(), nn.Linear(256, latent_dim))
self.decoder = nn.Sequential(nn.Linear(latent_dim, 256), nn.ReLU(), nn.Linear(256, 28 * 28))
self.model = model
self.dropout = nn.Dropout()
self.mc_iteration = mc_iteration
def training_step(self, batch, batch_idx):
x, _ = batch
def predict_step(self, batch, batch_idx):
# enable Monte Carlo Dropout
self.dropout.train()
# encode
x = x.view(x.size(0), -1)
z = self.encoder(x)
# take average of `self.mc_iteration` iterations
pred = torch.vstack([self.dropout(self.model(x)).unsqueeze(0) for _ in range(self.mc_iteration)]).mean(dim=0)
return pred
# decode
recons = self.decoder(z)
# reconstruction
reconstruction_loss = nn.functional.mse_loss(recons, x)
return reconstruction_loss
def validation_step(self, batch, batch_idx):
x, _ = batch
x = x.view(x.size(0), -1)
z = self.encoder(x)
recons = self.decoder(z)
reconstruction_loss = nn.functional.mse_loss(recons, x)
self.log("val_reconstruction", reconstruction_loss)
def predict_step(self, batch, batch_idx, dataloader_idx=0):
x, _ = batch
# encode
# for predictions, we could return the embedding or the reconstruction or both based on our need.
x = x.view(x.size(0), -1)
return self.encoder(x)
def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=0.0002)
Which can be trained like this:
Under the hood, Lightning does the following (pseudocode):
.. code-block:: python
autoencoder = Autoencoder()
trainer = pl.Trainer(gpus=1)
trainer.fit(autoencoder, train_dataloader, val_dataloader)
# disable grads + batchnorm + dropout
torch.set_grad_enabled(False)
model.eval()
all_preds = []
This simple model generates examples that look like this (the encoders and decoders are too weak)
for batch_idx, batch in enumerate(predict_dataloader):
pred = model.predict_step(batch, batch_idx)
all_preds.append(pred)
.. figure:: https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/ae_docs.png
:width: 300
The methods above are part of the LightningModule interface:
- training_step
- validation_step
- test_step
- predict_step
- configure_optimizers
Note that in this case, the train loop and val loop are exactly the same. We can, of course, reuse this code.
There are two ways to call ``predict()``:
.. code-block:: python
class Autoencoder(pl.LightningModule):
def __init__(self, latent_dim=2):
super().__init__()
self.encoder = nn.Sequential(nn.Linear(28 * 28, 256), nn.ReLU(), nn.Linear(256, latent_dim))
self.decoder = nn.Sequential(nn.Linear(latent_dim, 256), nn.ReLU(), nn.Linear(256, 28 * 28))
# call after training
trainer = Trainer()
trainer.fit(model)
def training_step(self, batch, batch_idx):
loss = self.shared_step(batch)
# automatically auto-loads the best weights from the previous run
predictions = trainer.predict(dataloaders=predict_dataloader)
return loss
def validation_step(self, batch, batch_idx):
loss = self.shared_step(batch)
self.log("val_loss", loss)
def shared_step(self, batch):
x, _ = batch
# encode
x = x.view(x.size(0), -1)
z = self.encoder(x)
# decode
recons = self.decoder(z)
# loss
return nn.functional.mse_loss(recons, x)
def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=0.0002)
We create a new method called ``shared_step`` that all loops can use. This method name is arbitrary and NOT reserved.
# or call with pretrained model
model = MyLightningModule.load_from_checkpoint(PATH)
trainer = Trainer()
predictions = trainer.predict(model, dataloaders=test_dataloader)
Inference in Research
^^^^^^^^^^^^^^^^^^^^^
=====================
If you want to perform inference with the system, you can add a ``forward`` method to the LightningModule.
.. note:: When using ``forward``, you are responsible for calling :func:`~torch.nn.Module.eval` and using the :func:`~torch.no_grad` context manager.
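For example, a self-contained sketch (illustrative names; not necessarily the example used on this page):

.. code-block:: python

    import torch
    from torch import nn
    import pytorch_lightning as pl


    class LitAutoEncoder(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))

        def forward(self, x):
            # the "practical use" of this system: produce embeddings
            return self.encoder(x)


    model = LitAutoEncoder()
    model.eval()  # you are responsible for eval() and no_grad here
    with torch.no_grad():
        embeddings = model(torch.rand(8, 28 * 28))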
@ -644,7 +633,8 @@ In the case where you want to scale your inference, you should be using
trainer.predict(model, data_module)
Inference in Production
^^^^^^^^^^^^^^^^^^^^^^^
=======================
For cases like production, you might want to iterate different models inside a LightningModule.
.. code-block:: python
@ -715,8 +705,8 @@ Tasks can be arbitrarily complex such as implementing GAN training, self-supervi
When used like this, the model can be separated from the Task and thus used in production without needing to keep it in
a ``LightningModule``.
- You can export to onnx using :meth:`~pytorch_lightning.core.lightning.LightningModule.to_onnx`.
- Or trace using Jit using :meth:`~pytorch_lightning.core.lightning.LightningModule.to_torchscript`.
- You can export to `ONNX <https://pytorch.org/docs/stable/onnx.html>`_ using :meth:`~pytorch_lightning.core.lightning.LightningModule.to_onnx`.
- Or trace with `TorchScript <https://pytorch.org/docs/stable/jit.html>`_ using :meth:`~pytorch_lightning.core.lightning.LightningModule.to_torchscript`.
- Or run in the Python runtime.
.. code-block:: python
@ -731,13 +721,25 @@ a ``LightningModule``.
with torch.no_grad():
y_hat = model(x)
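As a counterpart to the first two options above, a hedged sketch of the ONNX and TorchScript routes (file names and the input shape are illustrative):

.. code-block:: python

    model = MyLightningModule.load_from_checkpoint(PATH)

    # export to ONNX (pass an input sample, or set ``example_input_array`` beforehand)
    model.to_onnx("model.onnx", input_sample=torch.randn(1, 64))

    # or script/trace with TorchScript
    scripted = model.to_torchscript()
    torch.jit.save(scripted, "model.pt")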
-----------
*************
Child Modules
*************
.. include:: ../common/child_modules.rst
-----------
*******************
LightningModule API
-------------------
*******************
Methods
^^^^^^^
=======
all_gather
~~~~~~~~~~
@ -900,65 +902,61 @@ validation_epoch_end
.. automethod:: pytorch_lightning.core.lightning.LightningModule.validation_epoch_end
:noindex:
------------
-----------
Properties
^^^^^^^^^^
==========
-----------
These are properties available in a LightningModule.
current_epoch
~~~~~~~~~~~~~
The current epoch
.. code-block:: python
def training_step(self):
def training_step(self, batch, batch_idx):
if self.current_epoch == 0:
...
-------------
device
~~~~~~
The device the module is on. Use it to keep your code device agnostic.
.. code-block:: python
def training_step(self):
def training_step(self, batch, batch_idx):
z = torch.rand(2, 3, device=self.device)
-------------
global_rank
~~~~~~~~~~~
The ``global_rank`` is the index of the current process across all nodes and devices.
Lightning performs some operations, such as logging and weight checkpointing, only when ``global_rank == 0``. You
usually do not need to use this property, but it is useful to know how to access it if needed.
.. code-block:: python
def training_step(self):
def training_step(self, batch, batch_idx):
if self.global_rank == 0:
# do something only once across all the nodes
self.log("global_step", self.trainer.global_step)
-------------
global_step
~~~~~~~~~~~
The current step (does not reset each epoch)
.. code-block:: python
def training_step(self):
def training_step(self, batch, batch_idx):
self.logger.experiment.log_image(..., step=self.global_step)
-------------
hparams
~~~~~~~
The arguments passed through ``LightningModule.__init__()`` and saved by calling
:meth:`~pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin.save_hyperparameters` can be accessed via the ``hparams`` attribute.
@ -971,70 +969,64 @@ The arguments passed through ``LightningModule.__init__()`` and saved by calling
def configure_optimizers(self):
return Adam(self.parameters(), lr=self.hparams.learning_rate)
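For completeness, a sketch of the ``__init__`` side that makes this work (argument names are illustrative):

.. code-block:: python

    class LitModel(pl.LightningModule):
        def __init__(self, learning_rate=1e-3, hidden_dim=128):
            super().__init__()
            # stores learning_rate and hidden_dim under self.hparams
            self.save_hyperparameters()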
--------------
logger
~~~~~~
The current logger being used (tensorboard or other supported logger)
.. code-block:: python
def training_step(self):
def training_step(self, batch, batch_idx):
# the generic logger (same no matter if tensorboard or other supported logger)
self.logger
# the particular logger
tensorboard_logger = self.logger.experiment
--------------
local_rank
~~~~~~~~~~~
The ``local_rank`` is the index of the current process across the devices of the current node.
You usually do not need to use this property, but it is useful to know how to access it if needed.
For example, if using 10 machines (or nodes), the GPU at index 0 on each machine has local_rank = 0.
.. code-block:: python
def training_step(self):
def training_step(self, batch, batch_idx):
if self.local_rank == 0:
# do something once per node
self.log("global_step", self.trainer.global_step)
-----------
precision
~~~~~~~~~
The type of precision used:
.. code-block:: python
def training_step(self):
def training_step(self, batch, batch_idx):
if self.precision == 16:
...
------------
trainer
~~~~~~~
Pointer to the trainer
.. code-block:: python
def training_step(self):
def training_step(self, batch, batch_idx):
max_steps = self.trainer.max_steps
any_flag = self.trainer.any_flag
------------
use_amp
~~~~~~~
``True`` if using Automatic Mixed Precision (AMP)
------------
prepare_data_per_node
~~~~~~~~~~~~~~~~~~~~~
If set to ``True``, ``prepare_data()`` is called on LOCAL_RANK=0 for every node.
If set to ``False``, it is called only on NODE_RANK=0, LOCAL_RANK=0.
@ -1045,10 +1037,9 @@ If set to ``False`` will only call from NODE_RANK=0, LOCAL_RANK=0.
super().__init__()
self.prepare_data_per_node = True
------------
automatic_optimization
~~~~~~~~~~~~~~~~~~~~~~
When set to ``False``, Lightning does not automate the optimization process. This means you are responsible for handling
your optimizers. However, we do take care of precision and any accelerators used.
@ -1092,10 +1083,9 @@ Manual optimization is most useful for research topics like reinforcement learni
self.manual_backward(disc_loss)
opt_b.step()
--------------
example_input_array
~~~~~~~~~~~~~~~~~~~
Set and access ``example_input_array``, which represents a single batch.
.. code-block:: python
@ -1109,28 +1099,13 @@ Set and access example_input_array, which basically represents a single batch.
# generate some images using the example_input_array
gen_images = self.generator(self.example_input_array)
--------------
datamodule
~~~~~~~~~~
Set or access your datamodule.
.. code-block:: python
def configure_optimizers(self):
num_training_samples = len(self.trainer.datamodule.train_dataloader())
...
--------------
model_size
~~~~~~~~~~
Get the model file size (in megabytes) using ``self.model_size`` inside the LightningModule.
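For example (a sketch; the reported value depends on your model):

.. code-block:: python

    def on_fit_start(self):
        # model_size is reported in megabytes
        print(f"model size: {self.model_size:.2f} MB")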
--------------
truncated_bptt_steps
^^^^^^^^^^^^^^^^^^^^
~~~~~~~~~~~~~~~~~~~~
Truncated Backpropagation Through Time (TBPTT) performs backpropagation every k steps of
a much longer sequence. This is made possible by passing training batches
@ -1205,7 +1180,8 @@ override the :meth:`pytorch_lightning.core.lightning.LightningModule.tbptt_split
--------------
Hooks
^^^^^
=====
This is the pseudocode to describe the structure of :meth:`~pytorch_lightning.trainer.Trainer.fit`.
The inputs and outputs of each function are not represented for simplicity. Please check each function's API reference
for more information.
@ -1287,17 +1263,20 @@ for more information.
on_epoch_start()
on_validation_epoch_start()
for batch in val_dataloader():
on_validation_batch_start()
val_outs = []
for batch_idx, batch in enumerate(val_dataloader()):
on_validation_batch_start(batch, batch_idx)
on_before_batch_transfer()
transfer_batch_to_device()
on_after_batch_transfer()
batch = on_before_batch_transfer(batch)
batch = transfer_batch_to_device(batch)
batch = on_after_batch_transfer(batch)
validation_step()
out = validation_step(batch, batch_idx)
on_validation_batch_end()
validation_epoch_end()
on_validation_batch_end(batch, batch_idx)
val_outs.append(out)
validation_epoch_end(val_outs)
on_validation_epoch_end()
on_epoch_end()

View File

@ -70,7 +70,6 @@ PyTorch Lightning
clouds/cloud_training
clouds/cluster
common/child_modules
common/debugging
common/early_stopping
common/hyperparameters

View File

@ -991,6 +991,10 @@ And pass the callbacks into the trainer
----------------
*************
Child Modules
*************
.. include:: ../common/child_modules.rst
----------------