Update child modules docs (#11198)
parent 34c62da37d · commit 9092cf3a35
@ -2,11 +2,13 @@
|
|||
|
||||
from pytorch_lightning.core.lightning import LightningModule
|
||||
|
||||
#################
|
||||
Transfer Learning
|
||||
-----------------
|
||||
#################
|
||||
|
||||
***********************
|
||||
Using Pretrained Models
|
||||
^^^^^^^^^^^^^^^^^^^^^^^
|
||||
***********************
|
||||
|
||||
Sometimes we want to use a LightningModule as a pretrained model. This is fine because
|
||||
a LightningModule is just a `torch.nn.Module`!
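For instance, a minimal sketch of the idea (the ``AutoEncoder`` class, the checkpoint ``PATH``, and the 64-dimensional feature size are assumptions for illustration):

.. code-block:: python

    class ImageClassifier(LightningModule):
        def __init__(self):
            super().__init__()
            # the pretrained LightningModule is used like any ``torch.nn.Module``
            self.feature_extractor = AutoEncoder.load_from_checkpoint(PATH)
            self.feature_extractor.freeze()
            self.classifier = torch.nn.Linear(64, 10)

        def forward(self, x):
            # reuse only the encoder part of the pretrained model
            features = self.feature_extractor.encoder(x)
            return self.classifier(features)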
|
||||
|
@ -44,8 +46,8 @@ Let's use the `AutoEncoder` as a feature extractor in a separate model.
|
|||
|
||||
We used our pretrained Autoencoder (a LightningModule) for transfer learning!
|
||||
|
||||
Example: Imagenet (computer Vision)
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
Example: Imagenet (Computer Vision)
|
||||
===================================
|
||||
|
||||
.. testcode::
|
||||
:skipif: not _TORCHVISION_AVAILABLE
|
||||
|
@ -96,7 +98,8 @@ We used a pretrained model on imagenet, finetuned on CIFAR-10 to predict on CIFA
|
|||
In the non-academic world, you would finetune on your own small dataset and predict on your own data.
|
||||
|
||||
Example: BERT (NLP)
|
||||
^^^^^^^^^^^^^^^^^^^
|
||||
===================
|
||||
|
||||
Lightning is completely agnostic to what's used for transfer learning so long
|
||||
as it is a `torch.nn.Module` subclass.
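For example, here is a hedged sketch of wrapping a Hugging Face ``BertModel`` inside a LightningModule (the ``transformers`` dependency and the classification head are assumptions for illustration, not part of this page's example):

.. code-block:: python

    from transformers import BertModel


    class BertClassifier(LightningModule):
        def __init__(self, num_classes=2):
            super().__init__()
            # any ``torch.nn.Module`` subclass can be used as a backbone
            self.bert = BertModel.from_pretrained("bert-base-uncased")
            self.head = torch.nn.Linear(self.bert.config.hidden_size, num_classes)

        def forward(self, input_ids, attention_mask):
            outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
            # classify from the pooled [CLS] representation
            return self.head(outputs.pooler_output)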
|
||||
|
||||
|
|
|
@ -1,59 +1,41 @@
|
|||
.. testsetup:: *
|
||||
|
||||
import torch
|
||||
from pytorch_lightning.trainer.trainer import Trainer
|
||||
from pytorch_lightning.callbacks.base import Callback
|
||||
from pytorch_lightning.core.lightning import LightningModule
|
||||
|
||||
|
||||
class LitMNIST(LightningModule):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def train_dataloader():
|
||||
pass
|
||||
|
||||
def val_dataloader():
|
||||
pass
|
||||
|
||||
def test_dataloader():
|
||||
pass
|
||||
|
||||
Child Modules
|
||||
-------------
|
||||
Research projects tend to test different approaches to the same dataset.
|
||||
This is very easy to do in Lightning with inheritance.
|
||||
|
||||
For example, imagine we now want to train an Autoencoder to use as a feature extractor for MNIST images.
|
||||
We are extending our Autoencoder from the `LitMNIST`-module which already defines all the dataloading.
|
||||
The only things that change in the `Autoencoder` model are the init, forward, training, validation and test step.
|
||||
For example, imagine we now want to train an ``AutoEncoder`` to use as a feature extractor for images.
|
||||
The only things that change in the ``LitAutoEncoder`` model are the init, forward, training, validation and test step.
|
||||
|
||||
.. testcode::
|
||||
.. code-block:: python
|
||||
|
||||
class Encoder(torch.nn.Module):
|
||||
pass
|
||||
...
|
||||
|
||||
|
||||
class Decoder(torch.nn.Module):
|
||||
pass
|
||||
...
|
||||
|
||||
|
||||
class AutoEncoder(LitMNIST):
|
||||
class AutoEncoder(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.encoder = Encoder()
|
||||
self.decoder = Decoder()
|
||||
self.metric = MSE()
|
||||
|
||||
def forward(self, x):
|
||||
return self.encoder(x)
|
||||
return self.decoder(self.encoder(x))
|
||||
|
||||
|
||||
class LitAutoEncoder(LightningModule):
|
||||
def __init__(self, auto_encoder):
|
||||
super().__init__()
|
||||
self.auto_encoder = auto_encoder
|
||||
self.metric = torch.nn.MSELoss()
|
||||
|
||||
def forward(self, x):
|
||||
return self.auto_encoder.encoder(x)
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
x, _ = batch
|
||||
|
||||
representation = self.encoder(x)
|
||||
x_hat = self.decoder(representation)
|
||||
|
||||
x_hat = self.auto_encoder(x)
|
||||
loss = self.metric(x, x_hat)
|
||||
return loss
|
||||
|
||||
|
@ -65,25 +47,24 @@ The only things that change in the `Autoencoder` model are the init, forward, tr
|
|||
|
||||
def _shared_eval(self, batch, batch_idx, prefix):
|
||||
x, _ = batch
|
||||
representation = self.encoder(x)
|
||||
x_hat = self.decoder(representation)
|
||||
|
||||
x_hat = self.auto_encoder(x)
|
||||
loss = self.metric(x, x_hat)
|
||||
self.log(f"{prefix}_loss", loss)
|
||||
|
||||
|
||||
and we can train this using the same trainer
|
||||
and we can train this using the ``Trainer``:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
autoencoder = AutoEncoder()
|
||||
auto_encoder = AutoEncoder()
|
||||
lightning_module = LitAutoEncoder(auto_encoder)
|
||||
trainer = Trainer()
|
||||
trainer.fit(autoencoder)
|
||||
trainer.fit(lightning_module, train_dataloader, val_dataloader)
|
||||
|
||||
And remember that the forward method should define the practical use of a LightningModule.
|
||||
In this case, we want to use the `AutoEncoder` to extract image representations
|
||||
And remember that the forward method should define the practical use of a :class:`~pytorch_lightning.core.lightning.LightningModule`.
|
||||
In this case, we want to use the ``LitAutoEncoder`` to extract image representations:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
some_images = torch.Tensor(32, 1, 28, 28)
|
||||
representations = autoencoder(some_images)
|
||||
representations = lightning_module(some_images)
|
||||
|
|
|
@ -3,15 +3,17 @@
|
|||
|
||||
.. _lightning_module:
|
||||
|
||||
###############
|
||||
LightningModule
|
||||
===============
|
||||
###############
|
||||
|
||||
A :class:`~LightningModule` organizes your PyTorch code into 6 sections:
|
||||
|
||||
- Computations (init).
|
||||
- Train loop (training_step)
|
||||
- Validation loop (validation_step)
|
||||
- Test loop (test_step)
|
||||
- Prediction loop (predict_step)
|
||||
- Train Loop (training_step)
|
||||
- Validation Loop (validation_step)
|
||||
- Test Loop (test_step)
|
||||
- Prediction Loop (predict_step)
|
||||
- Optimizers and LR Schedulers (configure_optimizers)
|
||||
|
||||
|
|
||||
|
@ -85,8 +87,9 @@ Thus, to use Lightning, you just need to organize your code which takes about 30
|
|||
|
||||
------------
|
||||
|
||||
Minimal Example
|
||||
---------------
|
||||
***************
|
||||
Starter Example
|
||||
***************
|
||||
|
||||
Here are the only required methods.
|
||||
|
||||
|
@ -147,11 +150,13 @@ The LightningModule has many convenience methods, but the core ones you need to
|
|||
|
||||
----------
|
||||
|
||||
********
|
||||
Training
|
||||
--------
|
||||
********
|
||||
|
||||
Training Loop
|
||||
^^^^^^^^^^^^^
|
||||
=============
|
||||
|
||||
To activate the training loop, override the :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step` method.
|
||||
|
||||
.. code-block:: python
|
||||
|
@ -171,15 +176,14 @@ Under the hood, Lightning does the following (pseudocode):
|
|||
|
||||
.. code-block:: python
|
||||
|
||||
# put model in train mode
|
||||
# put model in train mode and enable gradient calculation
|
||||
model.train()
|
||||
torch.set_grad_enabled(True)
|
||||
|
||||
losses = []
|
||||
for batch in train_dataloader:
|
||||
# forward
|
||||
loss = training_step(batch)
|
||||
losses.append(loss.detach())
|
||||
outs = []
|
||||
for batch_idx, batch in enumerate(train_dataloader):
|
||||
loss = training_step(batch, batch_idx)
|
||||
outs.append(loss.detach())
|
||||
|
||||
# clear gradients
|
||||
optimizer.zero_grad()
|
||||
|
@ -191,8 +195,9 @@ Under the hood, Lightning does the following (pseudocode):
|
|||
optimizer.step()
|
||||
|
||||
|
||||
Training Epoch-Level Metrics
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
Train Epoch-level Metrics
|
||||
=========================
|
||||
|
||||
If you want to calculate epoch-level metrics and log them, use :meth:`~pytorch_lightning.core.lightning.LightningModule.log`.
|
||||
|
||||
.. code-block:: python
|
||||
|
@ -213,10 +218,10 @@ requested metrics across a complete epoch and devices. Here's the pseudocode of
|
|||
.. code-block:: python
|
||||
|
||||
outs = []
|
||||
for batch in train_dataloader:
|
||||
for batch_idx, batch in enumerate(train_dataloader):
|
||||
# forward
|
||||
out = training_step(val_batch)
|
||||
outs.append(out)
|
||||
loss = training_step(batch, batch_idx)
|
||||
outs.append(loss)
|
||||
|
||||
# clear gradients
|
||||
optimizer.zero_grad()
|
||||
|
@ -227,11 +232,12 @@ requested metrics across a complete epoch and devices. Here's the pseudocode of
|
|||
# update parameters
|
||||
optimizer.step()
|
||||
|
||||
epoch_metric = torch.mean(torch.stack([x["train_loss"] for x in outs]))
|
||||
epoch_metric = torch.mean(torch.stack([x for x in outs]))
|
||||
|
||||
Train Epoch-Level Operations
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
If you need to do something with all the outputs of each :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step`.
|
||||
Train Epoch-level Operations
|
||||
============================
|
||||
|
||||
If you need to do something with all the outputs of each :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step`,
|
||||
override the :meth:`~pytorch_lightning.core.lightning.LightningModule.training_epoch_end` method.
|
||||
|
||||
.. code-block:: python
|
||||
|
@ -253,10 +259,10 @@ The matching pseudocode is:
|
|||
.. code-block:: python
|
||||
|
||||
outs = []
|
||||
for batch in train_dataloader:
|
||||
for batch_idx, batch in enumerate(train_dataloader):
|
||||
# forward
|
||||
out = training_step(val_batch)
|
||||
outs.append(out)
|
||||
loss = training_step(batch, batch_idx)
|
||||
outs.append(loss)
|
||||
|
||||
# clear gradients
|
||||
optimizer.zero_grad()
|
||||
|
@ -270,7 +276,8 @@ The matching pseudocode is:
|
|||
training_epoch_end(outs)
|
||||
|
||||
Training with DataParallel
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
==========================
|
||||
|
||||
When training using a ``strategy`` that splits data from each batch across GPUs, sometimes you might
|
||||
need to aggregate them on the main GPU for processing (DP, or DDP2).
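As a hedged sketch (assuming a classification model), the per-GPU outputs returned by ``training_step`` can be reduced in ``training_step_end``:

.. code-block:: python

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        # each GPU computes a partial loss on its sub-batch
        loss = torch.nn.functional.cross_entropy(y_hat, y)
        return {"loss": loss}


    def training_step_end(self, batch_parts):
        # ``batch_parts["loss"]`` holds one loss per GPU; reduce them on the main GPU
        return batch_parts["loss"].mean()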
|
||||
|
||||
|
@ -309,12 +316,12 @@ Here is the Lightning training pseudo-code for DP:
|
|||
.. code-block:: python
|
||||
|
||||
outs = []
|
||||
for train_batch in train_dataloader:
|
||||
for batch_idx, train_batch in enumerate(train_dataloader):
|
||||
batches = split_batch(train_batch)
|
||||
dp_outs = []
|
||||
for sub_batch in batches:
|
||||
# 1
|
||||
dp_out = training_step(sub_batch)
|
||||
dp_out = training_step(sub_batch, batch_idx)
|
||||
dp_outs.append(dp_out)
|
||||
|
||||
# 2
|
||||
|
@ -327,8 +334,13 @@ Here is the Lightning training pseudo-code for DP:
|
|||
|
||||
------------------
|
||||
|
||||
**********
|
||||
Validation
|
||||
**********
|
||||
|
||||
Validation Loop
|
||||
^^^^^^^^^^^^^^^
|
||||
===============
|
||||
|
||||
To activate the validation loop while training, override the :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_step` method.
|
||||
|
||||
.. code-block:: python
|
||||
|
@ -345,8 +357,8 @@ Under the hood, Lightning does the following (pseudocode):
|
|||
.. code-block:: python
|
||||
|
||||
# ...
|
||||
for batch in train_dataloader:
|
||||
loss = model.training_step()
|
||||
for batch_idx, batch in enumerate(train_dataloader):
|
||||
loss = model.training_step(batch, batch_idx)
|
||||
loss.backward()
|
||||
# ...
|
||||
|
||||
|
@ -356,8 +368,8 @@ Under the hood, Lightning does the following (pseudocode):
|
|||
model.eval()
|
||||
|
||||
# ----------------- VAL LOOP ---------------
|
||||
for val_batch in model.val_dataloader:
|
||||
val_out = model.validation_step(val_batch)
|
||||
for val_batch_idx, val_batch in enumerate(val_dataloader):
|
||||
val_out = model.validation_step(val_batch, val_batch_idx)
|
||||
# ----------------- VAL LOOP ---------------
|
||||
|
||||
# enable grads + batchnorm + dropout
|
||||
|
@ -373,8 +385,18 @@ and calling :meth:`~pytorch_lightning.trainer.trainer.Trainer.validate`.
|
|||
trainer = Trainer()
|
||||
trainer.validate(model)
|
||||
|
||||
Validation Epoch-Level Metrics
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
.. note::
|
||||
|
||||
It is recommended to validate on a single device to ensure each sample/batch gets evaluated exactly once.
This is helpful to make sure benchmarking for research papers is done the right way. Otherwise, in a
multi-device setting, samples could be duplicated when :class:`~torch.utils.data.distributed.DistributedSampler`
is used, e.g. with ``strategy="ddp"``. It replicates some samples on some devices to make sure all devices have
the same batch size in case of uneven inputs.
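One way to follow this recommendation, sketched here with the standard ``Trainer`` flags (``model`` and ``val_dataloader`` are placeholders), is to run evaluation with a separate single-device trainer:

.. code-block:: python

    # train with multiple devices
    trainer = Trainer(accelerator="gpu", devices=8, strategy="ddp")
    trainer.fit(model)

    # validate/benchmark on a single device so each sample is evaluated exactly once
    eval_trainer = Trainer(accelerator="gpu", devices=1)
    eval_trainer.validate(model, dataloaders=val_dataloader)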
|
||||
|
||||
|
||||
Validation Epoch-level Metrics
|
||||
==============================
|
||||
|
||||
If you need to do something with all the outputs of each :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_step`,
|
||||
override the :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_epoch_end` method.
|
||||
|
||||
|
@ -393,7 +415,8 @@ override the :meth:`~pytorch_lightning.core.lightning.LightningModule.validation
|
|||
...
|
||||
|
||||
Validating with DataParallel
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
============================
|
||||
|
||||
When training using a ``strategy`` that splits data from each batch across GPUs, sometimes you might
|
||||
need to aggregate them on the main GPU for processing (DP, or DDP2).
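A matching sketch for validation (again assuming a classification model), where the sub-batch outputs are aggregated in ``validation_step_end``:

.. code-block:: python

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        # each GPU returns its sub-batch predictions and targets
        return {"y_hat": y_hat, "y": y}


    def validation_step_end(self, batch_parts):
        # aggregate the sub-batch outputs on the main GPU before logging
        y_hat, y = batch_parts["y_hat"], batch_parts["y"]
        self.log("val_acc", (y_hat.argmax(dim=-1) == y).float().mean())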
|
||||
|
||||
|
@ -450,8 +473,13 @@ Here is the Lightning validation pseudo-code for DP:
|
|||
|
||||
----------------
|
||||
|
||||
*******
|
||||
Testing
|
||||
*******
|
||||
|
||||
Test Loop
|
||||
^^^^^^^^^
|
||||
=========
|
||||
|
||||
The process for enabling a test loop is the same as the process for enabling a validation loop. Please refer to
|
||||
the section above for details. For this you need to override the :meth:`~pytorch_lightning.core.lightning.LightningModule.test_step` method.
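A minimal sketch of such an override (the loss function is an assumption for illustration):

.. code-block:: python

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = torch.nn.functional.cross_entropy(y_hat, y)
        # values logged here are aggregated over the whole test set
        self.log("test_loss", loss)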
|
||||
|
||||
|
@ -482,118 +510,79 @@ There are two ways to call ``test()``:
|
|||
trainer = Trainer()
|
||||
trainer.test(model, dataloaders=test_dataloader)
|
||||
|
||||
.. note::
|
||||
|
||||
It is recommended to validate on a single device to ensure each sample/batch gets evaluated exactly once.
This is helpful to make sure benchmarking for research papers is done the right way. Otherwise, in a
multi-device setting, samples could be duplicated when :class:`~torch.utils.data.distributed.DistributedSampler`
is used, e.g. with ``strategy="ddp"``. It replicates some samples on some devices to make sure all devices have
the same batch size in case of uneven inputs.
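The same pattern applies here; a short sketch mirroring the validation example above (same placeholder names):

.. code-block:: python

    eval_trainer = Trainer(accelerator="gpu", devices=1)
    eval_trainer.test(model, dataloaders=test_dataloader)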
|
||||
|
||||
|
||||
----------
|
||||
|
||||
Inference (Prediction Loop)
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
For research, LightningModules are best structured as systems.
|
||||
*********
|
||||
Inference
|
||||
*********
|
||||
|
||||
Prediction Loop
|
||||
===============
|
||||
|
||||
By default, the :meth:`~pytorch_lightning.core.lightning.LightningModule.predict_step` method runs the
|
||||
:meth:`~pytorch_lightning.core.lightning.LightningModule.forward` method. In order to customize this behaviour,
|
||||
simply override the :meth:`~pytorch_lightning.core.lightning.LightningModule.predict_step` method.
|
||||
|
||||
For this example, let's override ``predict_step`` and try out `Monte Carlo Dropout <https://arxiv.org/pdf/1506.02142.pdf>`_:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
class Autoencoder(pl.LightningModule):
|
||||
def __init__(self, latent_dim=2):
|
||||
class LitMCdropoutModel(pl.LightningModule):
|
||||
def __init__(self, model, mc_iteration):
|
||||
super().__init__()
|
||||
self.encoder = nn.Sequential(nn.Linear(28 * 28, 256), nn.ReLU(), nn.Linear(256, latent_dim))
|
||||
self.decoder = nn.Sequential(nn.Linear(latent_dim, 256), nn.ReLU(), nn.Linear(256, 28 * 28))
|
||||
self.model = model
|
||||
self.dropout = nn.Dropout()
|
||||
self.mc_iteration = mc_iteration
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
x, _ = batch
|
||||
def predict_step(self, batch, batch_idx):
|
||||
# enable Monte Carlo Dropout
|
||||
self.dropout.train()
|
||||
|
||||
# encode
|
||||
x = x.view(x.size(0), -1)
|
||||
z = self.encoder(x)
|
||||
# take average of `self.mc_iteration` iterations
|
||||
pred = torch.vstack([self.dropout(self.model(x)).unsqueeze(0) for _ in range(self.mc_iteration)]).mean(dim=0)
|
||||
return pred
|
||||
|
||||
# decode
|
||||
recons = self.decoder(z)
|
||||
|
||||
# reconstruction
|
||||
reconstruction_loss = nn.functional.mse_loss(recons, x)
|
||||
return reconstruction_loss
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
x, _ = batch
|
||||
x = x.view(x.size(0), -1)
|
||||
z = self.encoder(x)
|
||||
recons = self.decoder(z)
|
||||
reconstruction_loss = nn.functional.mse_loss(recons, x)
|
||||
self.log("val_reconstruction", reconstruction_loss)
|
||||
|
||||
def predict_step(self, batch, batch_idx, dataloader_idx=0):
|
||||
x, _ = batch
|
||||
|
||||
# encode
|
||||
# for predictions, we could return the embedding or the reconstruction or both based on our need.
|
||||
x = x.view(x.size(0), -1)
|
||||
return self.encoder(x)
|
||||
|
||||
def configure_optimizers(self):
|
||||
return torch.optim.Adam(self.parameters(), lr=0.0002)
|
||||
|
||||
Which can be trained like this:
|
||||
Under the hood, Lightning does the following (pseudocode):
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
autoencoder = Autoencoder()
|
||||
trainer = pl.Trainer(gpus=1)
|
||||
trainer.fit(autoencoder, train_dataloader, val_dataloader)
|
||||
# disable grads + batchnorm + dropout
|
||||
torch.set_grad_enabled(False)
|
||||
model.eval()
|
||||
all_preds = []
|
||||
|
||||
This simple model generates examples that look like this (the encoders and decoders are too weak)
|
||||
for batch_idx, batch in enumerate(predict_dataloader):
|
||||
pred = model.predict_step(batch, batch_idx)
|
||||
all_preds.append(pred)
|
||||
|
||||
.. figure:: https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/ae_docs.png
|
||||
:width: 300
|
||||
|
||||
The methods above are part of the LightningModule interface:
|
||||
|
||||
- training_step
|
||||
- validation_step
|
||||
- test_step
|
||||
- predict_step
|
||||
- configure_optimizers
|
||||
|
||||
Note that in this case, the train loop and val loop are exactly the same. We can, of course, reuse this code.
|
||||
There are two ways to call ``predict()``:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
class Autoencoder(pl.LightningModule):
|
||||
def __init__(self, latent_dim=2):
|
||||
super().__init__()
|
||||
self.encoder = nn.Sequential(nn.Linear(28 * 28, 256), nn.ReLU(), nn.Linear(256, latent_dim))
|
||||
self.decoder = nn.Sequential(nn.Linear(latent_dim, 256), nn.ReLU(), nn.Linear(256, 28 * 28))
|
||||
# call after training
|
||||
trainer = Trainer()
|
||||
trainer.fit(model)
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
loss = self.shared_step(batch)
|
||||
# automatically loads the best weights from the previous run
|
||||
predictions = trainer.predict(dataloaders=predict_dataloader)
|
||||
|
||||
return loss
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
loss = self.shared_step(batch)
|
||||
self.log("val_loss", loss)
|
||||
|
||||
def shared_step(self, batch):
|
||||
x, _ = batch
|
||||
|
||||
# encode
|
||||
x = x.view(x.size(0), -1)
|
||||
z = self.encoder(x)
|
||||
|
||||
# decode
|
||||
recons = self.decoder(z)
|
||||
|
||||
# loss
|
||||
return nn.functional.mse_loss(recons, x)
|
||||
|
||||
def configure_optimizers(self):
|
||||
return torch.optim.Adam(self.parameters(), lr=0.0002)
|
||||
|
||||
We create a new method called ``shared_step`` that all loops can use. This method name is arbitrary and NOT reserved.
|
||||
# or call with pretrained model
|
||||
model = MyLightningModule.load_from_checkpoint(PATH)
|
||||
trainer = Trainer()
|
||||
predictions = trainer.predict(model, dataloaders=test_dataloader)
|
||||
|
||||
Inference in Research
|
||||
^^^^^^^^^^^^^^^^^^^^^
|
||||
=====================
|
||||
|
||||
If you want to perform inference with the system, you can add a ``forward`` method to the LightningModule.
|
||||
|
||||
.. note:: When using ``forward``, you are responsible for calling :func:`~torch.nn.Module.eval` and using the :func:`~torch.no_grad` context manager.
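For instance, a hedged sketch (``MyLightningModule``, ``PATH``, and ``some_images`` are placeholders, and ``forward`` is assumed to return embeddings):

.. code-block:: python

    model = MyLightningModule.load_from_checkpoint(PATH)
    model.eval()

    with torch.no_grad():
        # ``forward`` defines how the system is used at inference time
        embeddings = model(some_images)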
|
||||
|
@ -644,7 +633,8 @@ In the case where you want to scale your inference, you should be using
|
|||
trainer.predict(model, data_module)
|
||||
|
||||
Inference in Production
|
||||
^^^^^^^^^^^^^^^^^^^^^^^
|
||||
=======================
|
||||
|
||||
For cases like production, you might want to iterate over different models inside a LightningModule.
|
||||
|
||||
.. code-block:: python
|
||||
|
@ -715,8 +705,8 @@ Tasks can be arbitrarily complex such as implementing GAN training, self-supervi
|
|||
When used like this, the model can be separated from the Task and thus used in production without needing to keep it in
|
||||
a ``LightningModule``.
|
||||
|
||||
- You can export to onnx using :meth:`~pytorch_lightning.core.lightning.LightningModule.to_onnx`.
|
||||
- Or trace using Jit using :meth:`~pytorch_lightning.core.lightning.LightningModule.to_torchscript`.
|
||||
- You can export to `ONNX <https://pytorch.org/docs/stable/onnx.html>`_ using :meth:`~pytorch_lightning.core.lightning.LightningModule.to_onnx`.
|
||||
- Or trace using `TorchScript <https://pytorch.org/docs/stable/jit.html>`_ using :meth:`~pytorch_lightning.core.lightning.LightningModule.to_torchscript`.
|
||||
- Or run in the Python runtime.
|
||||
|
||||
.. code-block:: python
|
||||
|
@ -731,13 +721,25 @@ a ``LightningModule``.
|
|||
with torch.no_grad():
|
||||
y_hat = model(x)
|
||||
|
||||
|
||||
-----------
|
||||
|
||||
|
||||
*************
|
||||
Child Modules
|
||||
*************
|
||||
|
||||
.. include:: ../common/child_modules.rst
|
||||
|
||||
-----------
|
||||
|
||||
*******************
|
||||
LightningModule API
|
||||
-------------------
|
||||
*******************
|
||||
|
||||
|
||||
Methods
|
||||
^^^^^^^
|
||||
=======
|
||||
|
||||
all_gather
|
||||
~~~~~~~~~~
|
||||
|
@ -900,65 +902,61 @@ validation_epoch_end
|
|||
.. automethod:: pytorch_lightning.core.lightning.LightningModule.validation_epoch_end
|
||||
:noindex:
|
||||
|
||||
------------
|
||||
-----------
|
||||
|
||||
Properties
|
||||
^^^^^^^^^^
|
||||
These are properties available in a LightningModule.
|
||||
==========
|
||||
|
||||
-----------
|
||||
These are properties available in a LightningModule.
|
||||
|
||||
current_epoch
|
||||
~~~~~~~~~~~~~
|
||||
|
||||
The current epoch
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def training_step(self):
|
||||
def training_step(self, batch, batch_idx):
|
||||
if self.current_epoch == 0:
|
||||
...
|
||||
|
||||
-------------
|
||||
|
||||
device
|
||||
~~~~~~
|
||||
|
||||
The device the module is on. Use it to keep your code device agnostic.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def training_step(self):
|
||||
def training_step(self, batch, batch_idx):
|
||||
z = torch.rand(2, 3, device=self.device)
|
||||
|
||||
-------------
|
||||
|
||||
global_rank
|
||||
~~~~~~~~~~~
|
||||
|
||||
The ``global_rank`` is the index of the current process across all nodes and devices.
|
||||
Lightning will perform some operations such as logging and weight checkpointing only when ``global_rank=0``. You
|
||||
usually do not need to use this property, but it is useful to know how to access it if needed.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def training_step(self):
|
||||
def training_step(self, batch, batch_idx):
|
||||
if self.global_rank == 0:
|
||||
# do something only once across all the nodes
|
||||
self.log("global_step", self.trainer.global_step)
|
||||
|
||||
-------------
|
||||
|
||||
global_step
|
||||
~~~~~~~~~~~
|
||||
|
||||
The current step (does not reset each epoch)
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def training_step(self):
|
||||
def training_step(self, batch, batch_idx):
|
||||
self.logger.experiment.log_image(..., step=self.global_step)
|
||||
|
||||
-------------
|
||||
|
||||
hparams
|
||||
~~~~~~~
|
||||
|
||||
The arguments passed through ``LightningModule.__init__()`` and saved by calling
|
||||
:meth:`~pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin.save_hyperparameters` can be accessed via the ``hparams`` attribute.
|
||||
|
||||
|
@ -971,70 +969,64 @@ The arguments passed through ``LightningModule.__init__()`` and saved by calling
|
|||
def configure_optimizers(self):
|
||||
return Adam(self.parameters(), lr=self.hparams.learning_rate)
|
||||
|
||||
--------------
|
||||
|
||||
logger
|
||||
~~~~~~
|
||||
|
||||
The current logger being used (tensorboard or other supported logger)
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def training_step(self):
|
||||
def training_step(self, batch, batch_idx):
|
||||
# the generic logger (same no matter if tensorboard or other supported logger)
|
||||
self.logger
|
||||
|
||||
# the particular logger
|
||||
tensorboard_logger = self.logger.experiment
|
||||
|
||||
--------------
|
||||
|
||||
local_rank
|
||||
~~~~~~~~~~~
|
||||
|
||||
The ``local_rank`` is the index of the current process across all the devices for the current node.
|
||||
You usually do not need to use this property, but it is useful to know how to access it if needed.
|
||||
For example, if using 10 machines (or nodes), the GPU at index 0 on each machine has local_rank = 0.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def training_step(self):
|
||||
def training_step(self, batch, batch_idx):
|
||||
if self.local_rank == 0:
# do something only once per node
|
||||
self.log("global_step", self.trainer.global_step)
|
||||
|
||||
-----------
|
||||
|
||||
precision
|
||||
~~~~~~~~~
|
||||
|
||||
The type of precision used:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def training_step(self):
|
||||
def training_step(self, batch, batch_idx):
|
||||
if self.precision == 16:
|
||||
...
|
||||
|
||||
------------
|
||||
|
||||
trainer
|
||||
~~~~~~~
|
||||
|
||||
Pointer to the trainer
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def training_step(self):
|
||||
def training_step(self, batch, batch_idx):
|
||||
max_steps = self.trainer.max_steps
|
||||
any_flag = self.trainer.any_flag
|
||||
|
||||
------------
|
||||
|
||||
use_amp
|
||||
~~~~~~~
|
||||
``True`` if using Automatic Mixed Precision (AMP)
|
||||
|
||||
------------
|
||||
``True`` if using Automatic Mixed Precision (AMP)
|
||||
|
||||
prepare_data_per_node
|
||||
~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
If set to ``True``, ``prepare_data()`` is called on LOCAL_RANK=0 for every node.
If set to ``False``, it is called only from NODE_RANK=0, LOCAL_RANK=0.
|
||||
|
||||
|
@ -1045,10 +1037,9 @@ If set to ``False`` will only call from NODE_RANK=0, LOCAL_RANK=0.
|
|||
super().__init__()
|
||||
self.prepare_data_per_node = True
|
||||
|
||||
------------
|
||||
|
||||
automatic_optimization
|
||||
~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
When set to ``False``, Lightning does not automate the optimization process. This means you are responsible for handling
|
||||
your optimizers. However, we do take care of precision and any accelerators used.
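A hedged sketch of the manual optimization flow with a single optimizer (``compute_loss`` is an assumed helper):

.. code-block:: python

    class LitModel(LightningModule):
        def __init__(self):
            super().__init__()
            self.automatic_optimization = False

        def training_step(self, batch, batch_idx):
            opt = self.optimizers()
            opt.zero_grad()

            loss = self.compute_loss(batch)
            # ``manual_backward`` keeps precision/accelerator handling intact
            self.manual_backward(loss)
            opt.step()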
|
||||
|
||||
|
@ -1092,10 +1083,9 @@ Manual optimization is most useful for research topics like reinforcement learni
|
|||
self.manual_backward(disc_loss)
|
||||
opt_b.step()
|
||||
|
||||
--------------
|
||||
|
||||
example_input_array
|
||||
~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
Set and access ``example_input_array``, which represents a single batch of input.
|
||||
|
||||
.. code-block:: python
|
||||
|
@ -1109,28 +1099,13 @@ Set and access example_input_array, which basically represents a single batch.
|
|||
# generate some images using the example_input_array
|
||||
gen_images = self.generator(self.example_input_array)
|
||||
|
||||
--------------
|
||||
|
||||
datamodule
|
||||
~~~~~~~~~~
|
||||
Set or access your datamodule.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def configure_optimizers(self):
|
||||
num_training_samples = len(self.trainer.datamodule.train_dataloader())
|
||||
...
|
||||
|
||||
--------------
|
||||
|
||||
model_size
|
||||
~~~~~~~~~~
|
||||
|
||||
Get the model file size (in megabytes) using ``self.model_size`` inside LightningModule.
|
||||
|
||||
--------------
|
||||
|
||||
truncated_bptt_steps
|
||||
^^^^^^^^^^^^^^^^^^^^
|
||||
~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
Truncated Backpropagation Through Time (TBPTT) performs backpropagation every k steps of
|
||||
a much longer sequence. This is made possible by passing training batches
|
||||
|
@ -1205,7 +1180,8 @@ override the :meth:`pytorch_lightning.core.lightning.LightningModule.tbptt_split
|
|||
--------------
|
||||
|
||||
Hooks
|
||||
^^^^^
|
||||
=====
|
||||
|
||||
This is the pseudocode to describe the structure of :meth:`~pytorch_lightning.trainer.Trainer.fit`.
|
||||
The inputs and outputs of each function are not represented for simplicity. Please check each function's API reference
|
||||
for more information.
|
||||
|
@ -1287,17 +1263,20 @@ for more information.
|
|||
on_epoch_start()
|
||||
on_validation_epoch_start()
|
||||
|
||||
for batch in val_dataloader():
|
||||
on_validation_batch_start()
|
||||
val_outs = []
|
||||
for batch_idx, batch in enumerate(val_dataloader()):
|
||||
on_validation_batch_start(batch, batch_idx)
|
||||
|
||||
on_before_batch_transfer()
|
||||
transfer_batch_to_device()
|
||||
on_after_batch_transfer()
|
||||
batch = on_before_batch_transfer(batch)
|
||||
batch = transfer_batch_to_device(batch)
|
||||
batch = on_after_batch_transfer(batch)
|
||||
|
||||
validation_step()
|
||||
out = validation_step(batch, batch_idx)
|
||||
|
||||
on_validation_batch_end()
|
||||
validation_epoch_end()
|
||||
on_validation_batch_end(batch, batch_idx)
|
||||
val_outs.append(out)
|
||||
|
||||
validation_epoch_end(val_outs)
|
||||
|
||||
on_validation_epoch_end()
|
||||
on_epoch_end()
|
||||
|
|
|
@ -70,7 +70,6 @@ PyTorch Lightning
|
|||
|
||||
clouds/cloud_training
|
||||
clouds/cluster
|
||||
common/child_modules
|
||||
common/debugging
|
||||
common/early_stopping
|
||||
common/hyperparameters
|
||||
|
|
|
@ -991,6 +991,10 @@ And pass the callbacks into the trainer
|
|||
|
||||
----------------
|
||||
|
||||
*************
|
||||
Child Modules
|
||||
*************
|
||||
|
||||
.. include:: ../common/child_modules.rst
|
||||
|
||||
----------------