.. role:: hidden
|
|
:class: hidden-section
|
|
|
|
.. _lightning_module:
|
|
|
|
###############
|
|
LightningModule
|
|
###############
|
|
|
|
A :class:`~LightningModule` organizes your PyTorch code into 6 sections:
|
|
|
|
- Computations (init)
|
|
- Train Loop (training_step)
|
|
- Validation Loop (validation_step)
|
|
- Test Loop (test_step)
|
|
- Prediction Loop (predict_step)
|
|
- Optimizers and LR Schedulers (configure_optimizers)
|
|
|
|
|
|
|
|
|
.. raw:: html
|
|
|
|
<video width="100%" max-width="400px" controls autoplay muted playsinline src="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/pl_mod_vid.m4v"></video>
|
|
|
|
|
|
|
|
|
Notice a few things.
|
|
|
|
1. It is the SAME code.
|
|
2. The PyTorch code IS NOT abstracted - just organized.
|
|
3. All the other code that's not in the :class:`~LightningModule`
|
|
has been automated for you by the Trainer.
|
|
|
|
|
|
|
|
|
.. code-block:: python
|
|
|
|
net = Net()
|
|
trainer = Trainer()
|
|
trainer.fit(net)
|
|
|
|
4. There are no ``.cuda()`` or ``.to(device)`` calls required. Lightning does these for you.
|
|
|
|
|
|
|
|
|
.. code-block:: python
|
|
|
|
# don't do in Lightning
|
|
x = torch.Tensor(2, 3)
|
|
x = x.cuda()
|
|
x = x.to(device)
|
|
|
|
# do this instead
|
|
x = x # leave it alone!
|
|
|
|
# or to init a new tensor
|
|
new_x = torch.Tensor(2, 3)
|
|
new_x = new_x.type_as(x)
|
|
|
|
5. When running under a distributed strategy, Lightning handles the distributed sampler for you by default.
|
|
|
|
|
|
|
|
|
.. code-block:: python
|
|
|
|
# Don't do in Lightning...
|
|
data = MNIST(...)
|
|
sampler = DistributedSampler(data)
|
|
DataLoader(data, sampler=sampler)
|
|
|
|
# do this instead
|
|
data = MNIST(...)
|
|
DataLoader(data)
|
|
|
|
6. A :class:`~LightningModule` is a :class:`torch.nn.Module` but with added functionality. Use it as such!
|
|
|
|
|
|
|
|
|
.. code-block:: python
|
|
|
|
net = Net.load_from_checkpoint(PATH)
|
|
net.freeze()
|
|
out = net(x)
|
|
|
|
Thus, to use Lightning, you just need to organize your code, which takes about 30 minutes
(and let's be real, you probably should have done this anyway).
|
|
|
|
------------
|
|
|
|
***************
|
|
Starter Example
|
|
***************
|
|
|
|
Here are the only required methods.
|
|
|
|
.. code-block:: python
|
|
|
|
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    import pytorch_lightning as pl
|
|
|
|
|
|
class LitModel(pl.LightningModule):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.l1 = nn.Linear(28 * 28, 10)
|
|
|
|
def forward(self, x):
|
|
return torch.relu(self.l1(x.view(x.size(0), -1)))
|
|
|
|
def training_step(self, batch, batch_idx):
|
|
x, y = batch
|
|
y_hat = self(x)
|
|
loss = F.cross_entropy(y_hat, y)
|
|
return loss
|
|
|
|
def configure_optimizers(self):
|
|
return torch.optim.Adam(self.parameters(), lr=0.02)
|
|
|
|
Which you can train by doing:
|
|
|
|
.. code-block:: python
|
|
|
|
    import os

    from torch.utils.data import DataLoader
    from torchvision import transforms
    from torchvision.datasets import MNIST

    train_loader = DataLoader(MNIST(os.getcwd(), download=True, transform=transforms.ToTensor()))
    trainer = pl.Trainer(max_epochs=1)
    model = LitModel()
    trainer.fit(model, train_dataloaders=train_loader)
|
|
|
|
The LightningModule has many convenience methods, but the core ones you need to know about are:
|
|
|
|
.. list-table::
|
|
:widths: 50 50
|
|
:header-rows: 1
|
|
|
|
* - Name
|
|
- Description
|
|
* - init
|
|
- Define computations here
|
|
* - forward
|
|
- Use for inference only (separate from training_step)
|
|
* - training_step
|
|
- the complete training loop
|
|
* - validation_step
|
|
- the complete validation loop
|
|
* - test_step
|
|
- the complete test loop
|
|
* - predict_step
|
|
- the complete prediction loop
|
|
* - configure_optimizers
|
|
- define optimizers and LR schedulers
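
For orientation, here is a minimal skeleton showing where each of these methods lives. This is only a sketch:
the class name is a placeholder, the bodies are elided, and you only need to override the methods you actually use.

.. code-block:: python

    class LitSkeleton(pl.LightningModule):
        def __init__(self):
            super().__init__()
            # define computations / submodules here

        def forward(self, x):
            # inference only
            ...

        def training_step(self, batch, batch_idx):
            ...

        def validation_step(self, batch, batch_idx):
            ...

        def test_step(self, batch, batch_idx):
            ...

        def predict_step(self, batch, batch_idx):
            ...

        def configure_optimizers(self):
            ...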
|
|
|
|
----------
|
|
|
|
********
|
|
Training
|
|
********
|
|
|
|
Training Loop
|
|
=============
|
|
|
|
To activate the training loop, override the :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step` method.
|
|
|
|
.. code-block:: python
|
|
|
|
class LitClassifier(pl.LightningModule):
|
|
def __init__(self, model):
|
|
super().__init__()
|
|
self.model = model
|
|
|
|
def training_step(self, batch, batch_idx):
|
|
x, y = batch
|
|
y_hat = self.model(x)
|
|
loss = F.cross_entropy(y_hat, y)
|
|
return loss
|
|
|
|
Under the hood, Lightning does the following (pseudocode):
|
|
|
|
.. code-block:: python
|
|
|
|
# put model in train mode and enable gradient calculation
|
|
model.train()
|
|
torch.set_grad_enabled(True)
|
|
|
|
outs = []
|
|
for batch_idx, batch in enumerate(train_dataloader):
|
|
loss = training_step(batch, batch_idx)
|
|
outs.append(loss.detach())
|
|
|
|
# clear gradients
|
|
optimizer.zero_grad()
|
|
|
|
# backward
|
|
loss.backward()
|
|
|
|
# update parameters
|
|
optimizer.step()
|
|
|
|
|
|
Train Epoch-level Metrics
|
|
=========================
|
|
|
|
If you want to calculate epoch-level metrics and log them, use :meth:`~pytorch_lightning.core.lightning.LightningModule.log`.
|
|
|
|
.. code-block:: python
|
|
|
|
def training_step(self, batch, batch_idx):
|
|
x, y = batch
|
|
y_hat = self.model(x)
|
|
loss = F.cross_entropy(y_hat, y)
|
|
|
|
# logs metrics for each training_step,
|
|
# and the average across the epoch, to the progress bar and logger
|
|
self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
|
|
return loss
|
|
|
|
The :meth:`~pytorch_lightning.core.lightning.LightningModule.log` method automatically reduces the
requested metrics across a complete epoch and across devices. Here's the pseudocode of what it does under the hood:
|
|
|
|
.. code-block:: python
|
|
|
|
outs = []
|
|
for batch_idx, batch in enumerate(train_dataloader):
|
|
# forward
|
|
loss = training_step(batch, batch_idx)
|
|
outs.append(loss)
|
|
|
|
# clear gradients
|
|
optimizer.zero_grad()
|
|
|
|
# backward
|
|
loss.backward()
|
|
|
|
# update parameters
|
|
optimizer.step()
|
|
|
|
epoch_metric = torch.mean(torch.stack([x for x in outs]))
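
When training on multiple devices, you usually also want the logged value reduced across processes before it is
written out. Here is a minimal sketch, assuming a distributed run, that uses the ``sync_dist`` flag of ``self.log``
(the metric name is just an example):

.. code-block:: python

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = F.cross_entropy(self.model(x), y)

        # sync_dist=True reduces the value across processes before logging
        # (it is off by default because the reduction adds communication overhead)
        self.log("train_loss", loss, on_step=False, on_epoch=True, sync_dist=True)
        return loss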
|
|
|
|
Train Epoch-level Operations
|
|
============================
|
|
|
|
If you need to do something with all the outputs of each :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step`,
|
|
override the :meth:`~pytorch_lightning.core.lightning.LightningModule.training_epoch_end` method.
|
|
|
|
.. code-block:: python
|
|
|
|
def training_step(self, batch, batch_idx):
|
|
x, y = batch
|
|
y_hat = self.model(x)
|
|
loss = F.cross_entropy(y_hat, y)
|
|
preds = ...
|
|
return {"loss": loss, "other_stuff": preds}
|
|
|
|
|
|
def training_epoch_end(self, training_step_outputs):
|
|
all_preds = torch.stack(training_step_outputs)
|
|
...
|
|
|
|
The matching pseudocode is:
|
|
|
|
.. code-block:: python
|
|
|
|
outs = []
|
|
for batch_idx, batch in enumerate(train_dataloader):
|
|
# forward
|
|
loss = training_step(batch, batch_idx)
|
|
outs.append(loss)
|
|
|
|
# clear gradients
|
|
optimizer.zero_grad()
|
|
|
|
# backward
|
|
loss.backward()
|
|
|
|
# update parameters
|
|
optimizer.step()
|
|
|
|
training_epoch_end(outs)
|
|
|
|
Training with DataParallel
|
|
==========================
|
|
|
|
When training using a ``strategy`` that splits data from each batch across GPUs, sometimes you might
|
|
need to aggregate them on the main GPU for processing (DP, or DDP2).
|
|
|
|
In this case, implement the :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step_end`
|
|
method which will have outputs from all the devices and you can accumulate to get the effective results.
|
|
|
|
.. code-block:: python
|
|
|
|
def training_step(self, batch, batch_idx):
|
|
x, y = batch
|
|
y_hat = self.model(x)
|
|
loss = F.cross_entropy(y_hat, y)
|
|
pred = ...
|
|
return {"loss": loss, "pred": pred}
|
|
|
|
|
|
def training_step_end(self, batch_parts):
|
|
# predictions from each GPU
|
|
predictions = batch_parts["pred"]
|
|
# losses from each GPU
|
|
losses = batch_parts["loss"]
|
|
|
|
gpu_0_prediction = predictions[0]
|
|
gpu_1_prediction = predictions[1]
|
|
|
|
# do something with both outputs
|
|
return (losses[0] + losses[1]) / 2
|
|
|
|
|
|
def training_epoch_end(self, training_step_outputs):
|
|
for out in training_step_outputs:
|
|
...
|
|
|
|
Here is the Lightning training pseudo-code for DP:
|
|
|
|
.. code-block:: python
|
|
|
|
outs = []
|
|
for batch_idx, train_batch in enumerate(train_dataloader):
|
|
batches = split_batch(train_batch)
|
|
dp_outs = []
|
|
for sub_batch in batches:
|
|
# 1
|
|
dp_out = training_step(sub_batch, batch_idx)
|
|
dp_outs.append(dp_out)
|
|
|
|
# 2
|
|
out = training_step_end(dp_outs)
|
|
outs.append(out)
|
|
|
|
# do something with the outputs for all batches
|
|
# 3
|
|
training_epoch_end(outs)
|
|
|
|
------------------
|
|
|
|
**********
|
|
Validation
|
|
**********
|
|
|
|
Validation Loop
|
|
===============
|
|
|
|
To activate the validation loop while training, override the :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_step` method.
|
|
|
|
.. code-block:: python
|
|
|
|
class LitModel(pl.LightningModule):
|
|
def validation_step(self, batch, batch_idx):
|
|
x, y = batch
|
|
y_hat = self.model(x)
|
|
loss = F.cross_entropy(y_hat, y)
|
|
self.log("val_loss", loss)
|
|
|
|
Under the hood, Lightning does the following (pseudocode):
|
|
|
|
.. code-block:: python
|
|
|
|
# ...
|
|
for batch_idx, batch in enumerate(train_dataloader):
|
|
loss = model.training_step(batch, batch_idx)
|
|
loss.backward()
|
|
# ...
|
|
|
|
if validate_at_some_point:
|
|
# disable grads + batchnorm + dropout
|
|
torch.set_grad_enabled(False)
|
|
model.eval()
|
|
|
|
# ----------------- VAL LOOP ---------------
|
|
for val_batch_idx, val_batch in enumerate(val_dataloader):
|
|
val_out = model.validation_step(val_batch, val_batch_idx)
|
|
# ----------------- VAL LOOP ---------------
|
|
|
|
# enable grads + batchnorm + dropout
|
|
torch.set_grad_enabled(True)
|
|
model.train()
|
|
|
|
You can also run just the validation loop on your validation dataloaders by overriding :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_step`
|
|
and calling :meth:`~pytorch_lightning.trainer.trainer.Trainer.validate`.
|
|
|
|
.. code-block:: python
|
|
|
|
model = Model()
|
|
trainer = Trainer()
|
|
trainer.validate(model)
|
|
|
|
.. note::
|
|
|
|
    It is recommended to validate on a single device to ensure each sample/batch gets evaluated exactly once.
    This is helpful to make sure benchmarking for research papers is done the right way. Otherwise, in a
    multi-device setting, samples could be duplicated when :class:`~torch.utils.data.distributed.DistributedSampler`
    is used, e.g. with ``strategy="ddp"``, because it replicates some samples on some devices to make sure all
    devices have the same batch size in case of uneven inputs.
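
For example, a minimal sketch of validating on a single device (the accelerator choice is just an example, and
``model`` and ``val_dataloader`` are assumed to be defined elsewhere):

.. code-block:: python

    # run the validation loop on a single device so that no sample is duplicated
    trainer = Trainer(accelerator="gpu", devices=1)
    trainer.validate(model, dataloaders=val_dataloader)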
|
|
|
|
|
|
Validation Epoch-level Metrics
|
|
==============================
|
|
|
|
If you need to do something with all the outputs of each :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_step`,
|
|
override the :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_epoch_end` method. Note that this method is called before :meth:`~pytorch_lightning.core.lightning.LightningModule.training_epoch_end`.
|
|
|
|
.. code-block:: python
|
|
|
|
def validation_step(self, batch, batch_idx):
|
|
x, y = batch
|
|
y_hat = self.model(x)
|
|
loss = F.cross_entropy(y_hat, y)
|
|
pred = ...
|
|
return pred
|
|
|
|
|
|
def validation_epoch_end(self, validation_step_outputs):
|
|
all_preds = torch.stack(validation_step_outputs)
|
|
...
|
|
|
|
Validating with DataParallel
|
|
============================
|
|
|
|
When training using a ``strategy`` that splits data from each batch across GPUs, sometimes you might
|
|
need to aggregate them on the main GPU for processing (DP, or DDP2).
|
|
|
|
In this case, implement the :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_step_end`
|
|
method which will have outputs from all the devices and you can accumulate to get the effective results.
|
|
|
|
.. code-block:: python
|
|
|
|
def validation_step(self, batch, batch_idx):
|
|
x, y = batch
|
|
y_hat = self.model(x)
|
|
loss = F.cross_entropy(y_hat, y)
|
|
pred = ...
|
|
return {"loss": loss, "pred": pred}
|
|
|
|
|
|
def validation_step_end(self, batch_parts):
|
|
# predictions from each GPU
|
|
predictions = batch_parts["pred"]
|
|
# losses from each GPU
|
|
losses = batch_parts["loss"]
|
|
|
|
gpu_0_prediction = predictions[0]
|
|
gpu_1_prediction = predictions[1]
|
|
|
|
# do something with both outputs
|
|
return (losses[0] + losses[1]) / 2
|
|
|
|
|
|
def validation_epoch_end(self, validation_step_outputs):
|
|
for out in validation_step_outputs:
|
|
...
|
|
|
|
Here is the Lightning validation pseudo-code for DP:
|
|
|
|
.. code-block:: python
|
|
|
|
outs = []
|
|
for batch in dataloader:
|
|
batches = split_batch(batch)
|
|
dp_outs = []
|
|
for sub_batch in batches:
|
|
# 1
|
|
dp_out = validation_step(sub_batch)
|
|
dp_outs.append(dp_out)
|
|
|
|
# 2
|
|
out = validation_step_end(dp_outs)
|
|
outs.append(out)
|
|
|
|
# do something with the outputs for all batches
|
|
# 3
|
|
validation_epoch_end(outs)
|
|
|
|
----------------
|
|
|
|
*******
|
|
Testing
|
|
*******
|
|
|
|
Test Loop
|
|
=========
|
|
|
|
The process for enabling a test loop is the same as the process for enabling a validation loop; please refer to
the section above for details. To enable it, override the :meth:`~pytorch_lightning.core.lightning.LightningModule.test_step` method.
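
A minimal sketch of a ``test_step``, mirroring the ``validation_step`` shown earlier (the logged metric name is
just an example):

.. code-block:: python

    class LitModel(pl.LightningModule):
        def test_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self.model(x)
            loss = F.cross_entropy(y_hat, y)
            self.log("test_loss", loss)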
|
|
|
|
The only difference is that the test loop is only called when :meth:`~pytorch_lightning.trainer.trainer.Trainer.test` is used.
|
|
|
|
.. code-block:: python
|
|
|
|
model = Model()
|
|
trainer = Trainer()
|
|
trainer.fit(model)
|
|
|
|
# automatically loads the best weights for you
|
|
trainer.test(model)
|
|
|
|
There are two ways to call ``test()``:
|
|
|
|
.. code-block:: python
|
|
|
|
# call after training
|
|
trainer = Trainer()
|
|
trainer.fit(model)
|
|
|
|
    # automatically loads the best weights from the previous run
|
|
trainer.test(dataloaders=test_dataloader)
|
|
|
|
# or call with pretrained model
|
|
model = MyLightningModule.load_from_checkpoint(PATH)
|
|
trainer = Trainer()
|
|
trainer.test(model, dataloaders=test_dataloader)
|
|
|
|
.. note::
|
|
|
|
    It is recommended to test on a single device to ensure each sample/batch gets evaluated exactly once.
    This is helpful to make sure benchmarking for research papers is done the right way. Otherwise, in a
    multi-device setting, samples could be duplicated when :class:`~torch.utils.data.distributed.DistributedSampler`
    is used, e.g. with ``strategy="ddp"``, because it replicates some samples on some devices to make sure all
    devices have the same batch size in case of uneven inputs.
|
|
|
|
|
|
----------
|
|
|
|
*********
|
|
Inference
|
|
*********
|
|
|
|
Prediction Loop
|
|
===============
|
|
|
|
By default, the :meth:`~pytorch_lightning.core.lightning.LightningModule.predict_step` method runs the
|
|
:meth:`~pytorch_lightning.core.lightning.LightningModule.forward` method. In order to customize this behaviour,
|
|
simply override the :meth:`~pytorch_lightning.core.lightning.LightningModule.predict_step` method.
|
|
|
|
For this example, let's override ``predict_step`` and try out `Monte Carlo Dropout <https://arxiv.org/pdf/1506.02142.pdf>`_:
|
|
|
|
.. code-block:: python
|
|
|
|
class LitMCdropoutModel(pl.LightningModule):
|
|
def __init__(self, model, mc_iteration):
|
|
super().__init__()
|
|
self.model = model
|
|
self.dropout = nn.Dropout()
|
|
self.mc_iteration = mc_iteration
|
|
|
|
def predict_step(self, batch, batch_idx):
|
|
# enable Monte Carlo Dropout
|
|
self.dropout.train()
|
|
|
|
# take average of `self.mc_iteration` iterations
|
|
            # here the batch is assumed to be the input tensor itself
            pred = torch.vstack([self.dropout(self.model(batch)).unsqueeze(0) for _ in range(self.mc_iteration)]).mean(dim=0)
|
|
return pred
|
|
|
|
Under the hood, Lightning does the following (pseudocode):
|
|
|
|
.. code-block:: python
|
|
|
|
# disable grads + batchnorm + dropout
|
|
torch.set_grad_enabled(False)
|
|
model.eval()
|
|
all_preds = []
|
|
|
|
for batch_idx, batch in enumerate(predict_dataloader):
|
|
pred = model.predict_step(batch, batch_idx)
|
|
all_preds.append(pred)
|
|
|
|
There are two ways to call ``predict()``:
|
|
|
|
.. code-block:: python
|
|
|
|
# call after training
|
|
trainer = Trainer()
|
|
trainer.fit(model)
|
|
|
|
    # automatically loads the best weights from the previous run
|
|
predictions = trainer.predict(dataloaders=predict_dataloader)
|
|
|
|
# or call with pretrained model
|
|
model = MyLightningModule.load_from_checkpoint(PATH)
|
|
trainer = Trainer()
|
|
predictions = trainer.predict(model, dataloaders=test_dataloader)
|
|
|
|
Inference in Research
|
|
=====================
|
|
|
|
If you want to perform inference with the system, you can add a ``forward`` method to the LightningModule.
|
|
|
|
.. note:: When using ``forward``, you are responsible for calling :func:`~torch.nn.Module.eval` and using the :func:`~torch.no_grad` context manager.
|
|
|
|
.. code-block:: python
|
|
|
|
class Autoencoder(pl.LightningModule):
|
|
def forward(self, x):
|
|
return self.decoder(x)
|
|
|
|
|
|
model = Autoencoder()
|
|
model.eval()
|
|
with torch.no_grad():
|
|
reconstruction = model(embedding)
|
|
|
|
The advantage of adding a forward is that in complex systems, you can do a much more involved inference procedure,
|
|
such as text generation:
|
|
|
|
.. code-block:: python
|
|
|
|
class Seq2Seq(pl.LightningModule):
|
|
def forward(self, x):
|
|
            embeddings = self.embedding(x)  # assumes an embedding layer defined in __init__
|
|
hidden_states = self.encoder(embeddings)
|
|
for h in hidden_states:
|
|
# decode
|
|
...
|
|
return decoded
|
|
|
|
In the case where you want to scale your inference, you should be using
|
|
:meth:`~pytorch_lightning.core.lightning.LightningModule.predict_step`.
|
|
|
|
.. code-block:: python
|
|
|
|
class Autoencoder(pl.LightningModule):
|
|
def forward(self, x):
|
|
return self.decoder(x)
|
|
|
|
def predict_step(self, batch, batch_idx, dataloader_idx=0):
|
|
# this calls forward
|
|
return self(batch)
|
|
|
|
|
|
data_module = ...
|
|
model = Autoencoder()
|
|
trainer = Trainer(accelerator="gpu", devices=2)
|
|
trainer.predict(model, data_module)
|
|
|
|
Inference in Production
|
|
=======================
|
|
|
|
For cases like production, you might want to iterate over different models inside a LightningModule.
|
|
|
|
.. code-block:: python
|
|
|
|
from torchmetrics.functional import accuracy
|
|
|
|
|
|
class ClassificationTask(pl.LightningModule):
|
|
def __init__(self, model):
|
|
super().__init__()
|
|
self.model = model
|
|
|
|
def training_step(self, batch, batch_idx):
|
|
x, y = batch
|
|
y_hat = self.model(x)
|
|
loss = F.cross_entropy(y_hat, y)
|
|
return loss
|
|
|
|
def validation_step(self, batch, batch_idx):
|
|
loss, acc = self._shared_eval_step(batch, batch_idx)
|
|
metrics = {"val_acc": acc, "val_loss": loss}
|
|
self.log_dict(metrics)
|
|
return metrics
|
|
|
|
def test_step(self, batch, batch_idx):
|
|
loss, acc = self._shared_eval_step(batch, batch_idx)
|
|
metrics = {"test_acc": acc, "test_loss": loss}
|
|
self.log_dict(metrics)
|
|
return metrics
|
|
|
|
def _shared_eval_step(self, batch, batch_idx):
|
|
x, y = batch
|
|
y_hat = self.model(x)
|
|
loss = F.cross_entropy(y_hat, y)
|
|
acc = accuracy(y_hat, y)
|
|
return loss, acc
|
|
|
|
def predict_step(self, batch, batch_idx, dataloader_idx=0):
|
|
x, y = batch
|
|
y_hat = self.model(x)
|
|
return y_hat
|
|
|
|
def configure_optimizers(self):
|
|
return torch.optim.Adam(self.model.parameters(), lr=0.02)
|
|
|
|
Then pass in any arbitrary model to be fit with this task:
|
|
|
|
.. code-block:: python
|
|
|
|
for model in [resnet50(), vgg16(), BidirectionalRNN()]:
|
|
task = ClassificationTask(model)
|
|
|
|
trainer = Trainer(accelerator="gpu", devices=2)
|
|
trainer.fit(task, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)
|
|
|
|
Tasks can be arbitrarily complex, such as implementing GAN training, self-supervised learning, or even RL.
|
|
|
|
.. code-block:: python
|
|
|
|
class GANTask(pl.LightningModule):
|
|
def __init__(self, generator, discriminator):
|
|
super().__init__()
|
|
self.generator = generator
|
|
self.discriminator = discriminator
|
|
|
|
...
|
|
|
|
When used like this, the model can be separated from the Task and thus used in production without needing to keep it in
|
|
a ``LightningModule``.
|
|
|
|
The following example shows how you can run inference in the Python runtime:
|
|
|
|
.. code-block:: python
|
|
|
|
task = ClassificationTask(model)
|
|
trainer = Trainer(accelerator="gpu", devices=2)
|
|
trainer.fit(task, train_dataloader, val_dataloader)
|
|
trainer.save_checkpoint("best_model.ckpt")
|
|
|
|
# use model after training or load weights and drop into the production system
|
|
model = ClassificationTask.load_from_checkpoint("best_model.ckpt")
|
|
x = ...
|
|
model.eval()
|
|
with torch.no_grad():
|
|
y_hat = model(x)
|
|
|
|
Check out the :ref:`Inference in Production <production_inference>` guide to learn about the possible ways to perform inference in production.
|
|
|
|
|
|
-----------
|
|
|
|
|
|
*************
|
|
Child Modules
|
|
*************
|
|
|
|
.. include:: ../common/child_modules.rst
|
|
|
|
-----------
|
|
|
|
*******************
|
|
LightningModule API
|
|
*******************
|
|
|
|
|
|
Methods
|
|
=======
|
|
|
|
all_gather
|
|
~~~~~~~~~~
|
|
|
|
.. automethod:: pytorch_lightning.core.lightning.LightningModule.all_gather
|
|
:noindex:
|
|
|
|
configure_callbacks
|
|
~~~~~~~~~~~~~~~~~~~
|
|
|
|
.. automethod:: pytorch_lightning.core.lightning.LightningModule.configure_callbacks
|
|
:noindex:
|
|
|
|
configure_optimizers
|
|
~~~~~~~~~~~~~~~~~~~~
|
|
|
|
.. automethod:: pytorch_lightning.core.lightning.LightningModule.configure_optimizers
|
|
:noindex:
|
|
|
|
forward
|
|
~~~~~~~
|
|
|
|
.. automethod:: pytorch_lightning.core.lightning.LightningModule.forward
|
|
:noindex:
|
|
|
|
freeze
|
|
~~~~~~
|
|
|
|
.. automethod:: pytorch_lightning.core.lightning.LightningModule.freeze
|
|
:noindex:
|
|
|
|
log
|
|
~~~
|
|
|
|
.. automethod:: pytorch_lightning.core.lightning.LightningModule.log
|
|
:noindex:
|
|
|
|
log_dict
|
|
~~~~~~~~
|
|
|
|
.. automethod:: pytorch_lightning.core.lightning.LightningModule.log_dict
|
|
:noindex:
|
|
|
|
lr_schedulers
|
|
~~~~~~~~~~~~~
|
|
|
|
.. automethod:: pytorch_lightning.core.lightning.LightningModule.lr_schedulers
|
|
:noindex:
|
|
|
|
manual_backward
|
|
~~~~~~~~~~~~~~~
|
|
|
|
.. automethod:: pytorch_lightning.core.lightning.LightningModule.manual_backward
|
|
:noindex:
|
|
|
|
optimizers
|
|
~~~~~~~~~~
|
|
|
|
.. automethod:: pytorch_lightning.core.lightning.LightningModule.optimizers
|
|
:noindex:
|
|
|
|
print
|
|
~~~~~
|
|
|
|
.. automethod:: pytorch_lightning.core.lightning.LightningModule.print
|
|
:noindex:
|
|
|
|
predict_step
|
|
~~~~~~~~~~~~
|
|
|
|
.. automethod:: pytorch_lightning.core.lightning.LightningModule.predict_step
|
|
:noindex:
|
|
|
|
save_hyperparameters
|
|
~~~~~~~~~~~~~~~~~~~~
|
|
|
|
.. automethod:: pytorch_lightning.core.lightning.LightningModule.save_hyperparameters
|
|
:noindex:
|
|
|
|
toggle_optimizer
|
|
~~~~~~~~~~~~~~~~
|
|
|
|
.. automethod:: pytorch_lightning.core.lightning.LightningModule.toggle_optimizer
|
|
:noindex:
|
|
|
|
test_step
|
|
~~~~~~~~~
|
|
|
|
.. automethod:: pytorch_lightning.core.lightning.LightningModule.test_step
|
|
:noindex:
|
|
|
|
test_step_end
|
|
~~~~~~~~~~~~~
|
|
|
|
.. automethod:: pytorch_lightning.core.lightning.LightningModule.test_step_end
|
|
:noindex:
|
|
|
|
test_epoch_end
|
|
~~~~~~~~~~~~~~
|
|
|
|
.. automethod:: pytorch_lightning.core.lightning.LightningModule.test_epoch_end
|
|
:noindex:
|
|
|
|
to_onnx
|
|
~~~~~~~
|
|
|
|
.. automethod:: pytorch_lightning.core.lightning.LightningModule.to_onnx
|
|
:noindex:
|
|
|
|
to_torchscript
|
|
~~~~~~~~~~~~~~
|
|
|
|
.. automethod:: pytorch_lightning.core.lightning.LightningModule.to_torchscript
|
|
:noindex:
|
|
|
|
training_step
|
|
~~~~~~~~~~~~~
|
|
|
|
.. automethod:: pytorch_lightning.core.lightning.LightningModule.training_step
|
|
:noindex:
|
|
|
|
training_step_end
|
|
~~~~~~~~~~~~~~~~~
|
|
|
|
.. automethod:: pytorch_lightning.core.lightning.LightningModule.training_step_end
|
|
:noindex:
|
|
|
|
training_epoch_end
|
|
~~~~~~~~~~~~~~~~~~
|
|
.. automethod:: pytorch_lightning.core.lightning.LightningModule.training_epoch_end
|
|
:noindex:
|
|
|
|
unfreeze
|
|
~~~~~~~~
|
|
|
|
.. automethod:: pytorch_lightning.core.lightning.LightningModule.unfreeze
|
|
:noindex:
|
|
|
|
untoggle_optimizer
|
|
~~~~~~~~~~~~~~~~~~
|
|
|
|
.. automethod:: pytorch_lightning.core.lightning.LightningModule.untoggle_optimizer
|
|
:noindex:
|
|
|
|
validation_step
|
|
~~~~~~~~~~~~~~~
|
|
|
|
.. automethod:: pytorch_lightning.core.lightning.LightningModule.validation_step
|
|
:noindex:
|
|
|
|
validation_step_end
|
|
~~~~~~~~~~~~~~~~~~~
|
|
|
|
.. automethod:: pytorch_lightning.core.lightning.LightningModule.validation_step_end
|
|
:noindex:
|
|
|
|
validation_epoch_end
|
|
~~~~~~~~~~~~~~~~~~~~
|
|
|
|
.. automethod:: pytorch_lightning.core.lightning.LightningModule.validation_epoch_end
|
|
:noindex:
|
|
|
|
-----------
|
|
|
|
Properties
|
|
==========
|
|
|
|
These are properties available in a LightningModule.
|
|
|
|
current_epoch
|
|
~~~~~~~~~~~~~
|
|
|
|
The index of the current epoch (starting from 0).
|
|
|
|
.. code-block:: python
|
|
|
|
def training_step(self, batch, batch_idx):
|
|
if self.current_epoch == 0:
|
|
...
|
|
|
|
device
|
|
~~~~~~
|
|
|
|
The device the module is on. Use it to keep your code device agnostic.
|
|
|
|
.. code-block:: python
|
|
|
|
def training_step(self, batch, batch_idx):
|
|
z = torch.rand(2, 3, device=self.device)
|
|
|
|
global_rank
|
|
~~~~~~~~~~~
|
|
|
|
The ``global_rank`` is the index of the current process across all nodes and devices.
Lightning performs some operations, such as logging and weight checkpointing, only when ``global_rank == 0``. You
usually do not need to use this property, but it is useful to know how to access it if needed.
|
|
|
|
.. code-block:: python
|
|
|
|
def training_step(self, batch, batch_idx):
|
|
if self.global_rank == 0:
|
|
# do something only once across all the nodes
|
|
...
|
|
|
|
global_step
|
|
~~~~~~~~~~~
|
|
|
|
The number of optimizer steps taken (does not reset each epoch).
|
|
This includes multiple optimizers and TBPTT steps (if enabled).
|
|
|
|
.. code-block:: python
|
|
|
|
def training_step(self, batch, batch_idx):
|
|
self.logger.experiment.log_image(..., step=self.global_step)
|
|
|
|
hparams
|
|
~~~~~~~
|
|
|
|
The arguments passed to ``LightningModule.__init__()`` and saved by calling
:meth:`~pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin.save_hyperparameters` can be accessed via the ``hparams`` attribute.
|
|
|
|
.. code-block:: python
|
|
|
|
    def __init__(self, learning_rate):
        super().__init__()
        self.save_hyperparameters()
|
|
|
|
|
|
def configure_optimizers(self):
|
|
return Adam(self.parameters(), lr=self.hparams.learning_rate)
|
|
|
|
logger
|
|
~~~~~~
|
|
|
|
The current logger being used (tensorboard or other supported logger)
|
|
|
|
.. code-block:: python
|
|
|
|
def training_step(self, batch, batch_idx):
|
|
# the generic logger (same no matter if tensorboard or other supported logger)
|
|
self.logger
|
|
|
|
# the particular logger
|
|
tensorboard_logger = self.logger.experiment
|
|
|
|
loggers
|
|
~~~~~~~
|
|
|
|
The list of loggers currently being used by the Trainer.
|
|
|
|
.. code-block:: python
|
|
|
|
def training_step(self, batch, batch_idx):
|
|
# List of LightningLoggerBase objects
|
|
loggers = self.loggers
|
|
for logger in loggers:
|
|
logger.log_metrics({"foo": 1.0})
|
|
|
|
local_rank
|
|
~~~~~~~~~~~
|
|
|
|
The ``local_rank`` is the index of the current process across all the devices for the current node.
|
|
You usually do not need to use this property, but it is useful to know how to access it if needed.
|
|
For example, if using 10 machines (or nodes), the GPU at index 0 on each machine has local_rank = 0.
|
|
|
|
.. code-block:: python
|
|
|
|
def training_step(self, batch, batch_idx):
|
|
if self.local_rank == 0:
|
|
# do something only once across each node
|
|
...
|
|
|
|
precision
|
|
~~~~~~~~~
|
|
|
|
The type of precision used:
|
|
|
|
.. code-block:: python
|
|
|
|
def training_step(self, batch, batch_idx):
|
|
if self.precision == 16:
|
|
...
|
|
|
|
trainer
|
|
~~~~~~~
|
|
|
|
Pointer to the trainer
|
|
|
|
.. code-block:: python
|
|
|
|
def training_step(self, batch, batch_idx):
|
|
max_steps = self.trainer.max_steps
|
|
        any_flag = self.trainer.any_flag  # any Trainer attribute can be accessed this way
|
|
|
|
prepare_data_per_node
|
|
~~~~~~~~~~~~~~~~~~~~~
|
|
|
|
If set to ``True``, ``prepare_data()`` is called on LOCAL_RANK=0 for every node.
If set to ``False``, it is called only on NODE_RANK=0, LOCAL_RANK=0.
|
|
|
|
.. testcode::
|
|
|
|
class LitModel(LightningModule):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.prepare_data_per_node = True
|
|
|
|
automatic_optimization
|
|
~~~~~~~~~~~~~~~~~~~~~~
|
|
|
|
When set to ``False``, Lightning does not automate the optimization process. This means you are responsible for handling
|
|
your optimizers. However, we do take care of precision and any accelerators used.
|
|
|
|
See :ref:`manual optimization <common/optimization:Manual optimization>` for details.
|
|
|
|
.. code-block:: python
|
|
|
|
def __init__(self):
|
|
self.automatic_optimization = False
|
|
|
|
|
|
def training_step(self, batch, batch_idx):
|
|
opt = self.optimizers(use_pl_optimizer=True)
|
|
|
|
loss = ...
|
|
opt.zero_grad()
|
|
self.manual_backward(loss)
|
|
opt.step()
|
|
|
|
This is recommended only if using 2+ optimizers AND if you know how to perform the optimization procedure properly. Note
|
|
that automatic optimization can still be used with multiple optimizers by relying on the ``optimizer_idx`` parameter.
|
|
Manual optimization is most useful for research topics like reinforcement learning, sparse coding, and GAN research.
|
|
|
|
.. code-block:: python
|
|
|
|
def __init__(self):
|
|
self.automatic_optimization = False
|
|
|
|
|
|
def training_step(self, batch, batch_idx):
|
|
        # access your optimizers (use_pl_optimizer=True is the default; pass False for the raw optimizers)
|
|
opt_a, opt_b = self.optimizers(use_pl_optimizer=True)
|
|
|
|
gen_loss = ...
|
|
opt_a.zero_grad()
|
|
self.manual_backward(gen_loss)
|
|
opt_a.step()
|
|
|
|
disc_loss = ...
|
|
opt_b.zero_grad()
|
|
self.manual_backward(disc_loss)
|
|
opt_b.step()
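
For comparison, here is a minimal sketch of keeping automatic optimization with two optimizers via the
``optimizer_idx`` argument mentioned above. The generator/discriminator attributes and the loss helpers are
hypothetical placeholders:

.. code-block:: python

    def configure_optimizers(self):
        opt_g = torch.optim.Adam(self.generator.parameters(), lr=0.01)
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=0.01)
        return opt_g, opt_d


    def training_step(self, batch, batch_idx, optimizer_idx):
        # Lightning calls training_step once per optimizer and passes its index
        if optimizer_idx == 0:
            return self._generator_loss(batch)  # hypothetical helper
        return self._discriminator_loss(batch)  # hypothetical helper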
|
|
|
|
example_input_array
|
|
~~~~~~~~~~~~~~~~~~~
|
|
|
|
Set and access ``example_input_array``, which represents a single batch of input data.
|
|
|
|
.. code-block:: python
|
|
|
|
def __init__(self):
|
|
self.example_input_array = ...
|
|
self.generator = ...
|
|
|
|
|
|
def on_train_epoch_end(self):
|
|
# generate some images using the example_input_array
|
|
gen_images = self.generator(self.example_input_array)
|
|
|
|
model_size
|
|
~~~~~~~~~~
|
|
|
|
Get the model file size (in megabytes) using ``self.model_size`` inside the LightningModule.
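
A minimal sketch of reading it, for example to print it once at the start of fitting (the hook choice is just one
option):

.. code-block:: python

    def on_fit_start(self):
        # model_size is reported in megabytes
        self.print(f"model size: {self.model_size:.1f} MB")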
|
|
|
|
truncated_bptt_steps
|
|
~~~~~~~~~~~~~~~~~~~~
|
|
|
|
Truncated Backpropagation Through Time (TBPTT) performs backpropagation every k steps of
a much longer sequence. This is made possible by passing training batches
split along the time dimension into splits of size k to the
``training_step``. In order to keep the same forward propagation behavior, all
hidden states should be kept in-between each time-dimension split.
|
|
|
|
|
|
If this is enabled, your batches will automatically get truncated
and the Trainer will apply Truncated Backprop to them.
|
|
|
|
(`Williams et al. "An efficient gradient-based algorithm for on-line training of
|
|
recurrent network trajectories."
|
|
<http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.56.7941&rep=rep1&type=pdf>`_)
|
|
|
|
`Tutorial <https://d2l.ai/chapter_recurrent-neural-networks/bptt.html>`_
|
|
|
|
.. testcode:: python
|
|
|
|
from pytorch_lightning import LightningModule
|
|
|
|
|
|
class MyModel(LightningModule):
|
|
def __init__(self, input_size, hidden_size, num_layers):
|
|
super().__init__()
|
|
# batch_first has to be set to True
|
|
self.lstm = nn.LSTM(
|
|
input_size=input_size,
|
|
hidden_size=hidden_size,
|
|
num_layers=num_layers,
|
|
batch_first=True,
|
|
)
|
|
|
|
...
|
|
|
|
# Important: This property activates truncated backpropagation through time
|
|
# Setting this value to 2 splits the batch into sequences of size 2
|
|
self.truncated_bptt_steps = 2
|
|
|
|
# Truncated back-propagation through time
|
|
def training_step(self, batch, batch_idx, hiddens):
|
|
x, y = batch
|
|
|
|
# the training step must be updated to accept a ``hiddens`` argument
|
|
# hiddens are the hiddens from the previous truncated backprop step
|
|
out, hiddens = self.lstm(x, hiddens)
|
|
|
|
...
|
|
|
|
return {"loss": ..., "hiddens": hiddens}
|
|
|
|
Lightning takes care of splitting your batch along the time-dimension. It is
|
|
assumed to be the second dimension of your batches. Therefore, in the
|
|
example above, we have set ``batch_first=True``.
|
|
|
|
.. code-block:: python
|
|
|
|
# we use the second as the time dimension
|
|
# (batch, time, ...)
|
|
sub_batch = batch[0, 0:t, ...]
|
|
|
|
To modify how the batch is split,
|
|
override the :meth:`pytorch_lightning.core.lightning.LightningModule.tbptt_split_batch` method:
|
|
|
|
.. testcode:: python
|
|
|
|
class LitMNIST(LightningModule):
|
|
        def tbptt_split_batch(self, batch, split_size):
            # do your own splitting on the batch
            splits = ...
            return splits
|
|
|
|
--------------
|
|
|
|
.. _lightning_hooks:
|
|
|
|
Hooks
|
|
=====
|
|
|
|
This is the pseudocode to describe the structure of :meth:`~pytorch_lightning.trainer.Trainer.fit`.
|
|
The inputs and outputs of each function are not represented for simplicity. Please check each function's API reference
|
|
for more information.
|
|
|
|
.. code-block:: python
|
|
|
|
def fit(self):
|
|
if global_rank == 0:
|
|
# prepare data is called on GLOBAL_ZERO only
|
|
prepare_data()
|
|
|
|
configure_callbacks()
|
|
|
|
with parallel(devices):
|
|
# devices can be GPUs, TPUs, ...
|
|
train_on_device(model)
|
|
|
|
|
|
def train_on_device(model):
|
|
# called PER DEVICE
|
|
on_fit_start()
|
|
setup("fit")
|
|
configure_optimizers()
|
|
|
|
# the sanity check runs here
|
|
|
|
on_train_start()
|
|
for epoch in epochs:
|
|
fit_loop()
|
|
on_train_end()
|
|
|
|
on_fit_end()
|
|
teardown("fit")
|
|
|
|
|
|
def fit_loop():
|
|
on_train_epoch_start()
|
|
|
|
for batch in train_dataloader():
|
|
on_train_batch_start()
|
|
|
|
on_before_batch_transfer()
|
|
transfer_batch_to_device()
|
|
on_after_batch_transfer()
|
|
|
|
training_step()
|
|
|
|
on_before_zero_grad()
|
|
optimizer_zero_grad()
|
|
|
|
on_before_backward()
|
|
backward()
|
|
on_after_backward()
|
|
|
|
on_before_optimizer_step()
|
|
configure_gradient_clipping()
|
|
optimizer_step()
|
|
|
|
on_train_batch_end()
|
|
|
|
if should_check_val:
|
|
val_loop()
|
|
# end training epoch
|
|
training_epoch_end()
|
|
|
|
on_train_epoch_end()
|
|
|
|
|
|
def val_loop():
|
|
on_validation_model_eval() # calls `model.eval()`
|
|
torch.set_grad_enabled(False)
|
|
|
|
on_validation_start()
|
|
on_validation_epoch_start()
|
|
|
|
val_outs = []
|
|
for batch_idx, batch in enumerate(val_dataloader()):
|
|
on_validation_batch_start(batch, batch_idx)
|
|
|
|
batch = on_before_batch_transfer(batch)
|
|
batch = transfer_batch_to_device(batch)
|
|
batch = on_after_batch_transfer(batch)
|
|
|
|
out = validation_step(batch, batch_idx)
|
|
|
|
on_validation_batch_end(batch, batch_idx)
|
|
val_outs.append(out)
|
|
|
|
validation_epoch_end(val_outs)
|
|
|
|
on_validation_epoch_end()
|
|
on_validation_end()
|
|
|
|
# set up for train
|
|
on_validation_model_train() # calls `model.train()`
|
|
torch.set_grad_enabled(True)
|
|
|
|
backward
|
|
~~~~~~~~
|
|
|
|
.. automethod:: pytorch_lightning.core.lightning.LightningModule.backward
|
|
:noindex:
|
|
|
|
on_before_backward
|
|
~~~~~~~~~~~~~~~~~~
|
|
|
|
.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_before_backward
|
|
:noindex:
|
|
|
|
on_after_backward
|
|
~~~~~~~~~~~~~~~~~
|
|
|
|
.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_after_backward
|
|
:noindex:
|
|
|
|
on_before_zero_grad
|
|
~~~~~~~~~~~~~~~~~~~
|
|
.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_before_zero_grad
|
|
:noindex:
|
|
|
|
on_fit_start
|
|
~~~~~~~~~~~~
|
|
|
|
.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_fit_start
|
|
:noindex:
|
|
|
|
on_fit_end
|
|
~~~~~~~~~~
|
|
|
|
.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_fit_end
|
|
:noindex:
|
|
|
|
|
|
on_load_checkpoint
|
|
~~~~~~~~~~~~~~~~~~
|
|
|
|
.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_load_checkpoint
|
|
:noindex:
|
|
|
|
on_save_checkpoint
|
|
~~~~~~~~~~~~~~~~~~
|
|
|
|
.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_save_checkpoint
|
|
:noindex:
|
|
|
|
load_from_checkpoint
|
|
~~~~~~~~~~~~~~~~~~~~
|
|
|
|
.. automethod:: pytorch_lightning.core.lightning.LightningModule.load_from_checkpoint
|
|
:noindex:
|
|
|
|
on_hpc_save
|
|
~~~~~~~~~~~
|
|
|
|
.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_hpc_save
|
|
:noindex:
|
|
|
|
on_hpc_load
|
|
~~~~~~~~~~~
|
|
|
|
.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_hpc_load
|
|
:noindex:
|
|
|
|
on_train_start
|
|
~~~~~~~~~~~~~~
|
|
|
|
.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_train_start
|
|
:noindex:
|
|
|
|
on_train_end
|
|
~~~~~~~~~~~~
|
|
|
|
.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_train_end
|
|
:noindex:
|
|
|
|
on_validation_start
|
|
~~~~~~~~~~~~~~~~~~~
|
|
|
|
.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_validation_start
|
|
:noindex:
|
|
|
|
on_validation_end
|
|
~~~~~~~~~~~~~~~~~
|
|
|
|
.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_validation_end
|
|
:noindex:
|
|
|
|
on_test_batch_start
|
|
~~~~~~~~~~~~~~~~~~~
|
|
|
|
.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_test_batch_start
|
|
:noindex:
|
|
|
|
on_test_batch_end
|
|
~~~~~~~~~~~~~~~~~
|
|
|
|
.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_test_batch_end
|
|
:noindex:
|
|
|
|
on_test_epoch_start
|
|
~~~~~~~~~~~~~~~~~~~
|
|
|
|
.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_test_epoch_start
|
|
:noindex:
|
|
|
|
on_test_epoch_end
|
|
~~~~~~~~~~~~~~~~~
|
|
|
|
.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_test_epoch_end
|
|
:noindex:
|
|
|
|
on_test_start
|
|
~~~~~~~~~~~~~
|
|
|
|
.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_test_start
|
|
:noindex:
|
|
|
|
on_test_end
|
|
~~~~~~~~~~~
|
|
|
|
.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_test_end
|
|
:noindex:
|
|
|
|
on_predict_batch_start
|
|
~~~~~~~~~~~~~~~~~~~~~~
|
|
|
|
.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_predict_batch_start
|
|
:noindex:
|
|
|
|
on_predict_batch_end
|
|
~~~~~~~~~~~~~~~~~~~~
|
|
|
|
.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_predict_batch_end
|
|
:noindex:
|
|
|
|
on_predict_epoch_start
|
|
~~~~~~~~~~~~~~~~~~~~~~
|
|
|
|
.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_predict_epoch_start
|
|
:noindex:
|
|
|
|
on_predict_epoch_end
|
|
~~~~~~~~~~~~~~~~~~~~
|
|
|
|
.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_predict_epoch_end
|
|
:noindex:
|
|
|
|
on_predict_start
|
|
~~~~~~~~~~~~~~~~
|
|
|
|
.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_predict_start
|
|
:noindex:
|
|
|
|
on_predict_end
|
|
~~~~~~~~~~~~~~
|
|
|
|
.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_predict_end
|
|
:noindex:
|
|
|
|
on_train_batch_start
|
|
~~~~~~~~~~~~~~~~~~~~
|
|
|
|
.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_train_batch_start
|
|
:noindex:
|
|
|
|
on_train_batch_end
|
|
~~~~~~~~~~~~~~~~~~
|
|
|
|
.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_train_batch_end
|
|
:noindex:
|
|
|
|
on_train_epoch_start
|
|
~~~~~~~~~~~~~~~~~~~~
|
|
|
|
.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_train_epoch_start
|
|
:noindex:
|
|
|
|
on_train_epoch_end
|
|
~~~~~~~~~~~~~~~~~~
|
|
|
|
.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_train_epoch_end
|
|
:noindex:
|
|
|
|
on_validation_batch_start
|
|
~~~~~~~~~~~~~~~~~~~~~~~~~
|
|
|
|
.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_validation_batch_start
|
|
:noindex:
|
|
|
|
on_validation_batch_end
|
|
~~~~~~~~~~~~~~~~~~~~~~~
|
|
|
|
.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_validation_batch_end
|
|
:noindex:
|
|
|
|
on_validation_epoch_start
|
|
~~~~~~~~~~~~~~~~~~~~~~~~~
|
|
|
|
.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_validation_epoch_start
|
|
:noindex:
|
|
|
|
on_validation_epoch_end
|
|
~~~~~~~~~~~~~~~~~~~~~~~
|
|
|
|
.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_validation_epoch_end
|
|
:noindex:
|
|
|
|
on_post_move_to_device
|
|
~~~~~~~~~~~~~~~~~~~~~~
|
|
|
|
.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_post_move_to_device
|
|
:noindex:
|
|
|
|
configure_sharded_model
|
|
~~~~~~~~~~~~~~~~~~~~~~~
|
|
|
|
.. automethod:: pytorch_lightning.core.lightning.LightningModule.configure_sharded_model
|
|
:noindex:
|
|
|
|
on_validation_model_eval
|
|
~~~~~~~~~~~~~~~~~~~~~~~~
|
|
|
|
.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_validation_model_eval
|
|
:noindex:
|
|
|
|
on_validation_model_train
|
|
~~~~~~~~~~~~~~~~~~~~~~~~~
|
|
|
|
.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_validation_model_train
|
|
:noindex:
|
|
|
|
on_test_model_eval
|
|
~~~~~~~~~~~~~~~~~~
|
|
|
|
.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_test_model_eval
|
|
:noindex:
|
|
|
|
on_test_model_train
|
|
~~~~~~~~~~~~~~~~~~~
|
|
|
|
.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_test_model_train
|
|
:noindex:
|
|
|
|
on_before_optimizer_step
|
|
~~~~~~~~~~~~~~~~~~~~~~~~
|
|
|
|
.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_before_optimizer_step
|
|
:noindex:
|
|
|
|
configure_gradient_clipping
|
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
|
|
|
.. automethod:: pytorch_lightning.core.lightning.LightningModule.configure_gradient_clipping
|
|
:noindex:
|
|
|
|
optimizer_step
|
|
~~~~~~~~~~~~~~
|
|
|
|
.. automethod:: pytorch_lightning.core.lightning.LightningModule.optimizer_step
|
|
:noindex:
|
|
|
|
optimizer_zero_grad
|
|
~~~~~~~~~~~~~~~~~~~
|
|
|
|
.. automethod:: pytorch_lightning.core.lightning.LightningModule.optimizer_zero_grad
|
|
:noindex:
|
|
|
|
prepare_data
|
|
~~~~~~~~~~~~
|
|
|
|
.. automethod:: pytorch_lightning.core.lightning.LightningModule.prepare_data
|
|
:noindex:
|
|
|
|
setup
|
|
~~~~~
|
|
|
|
.. automethod:: pytorch_lightning.core.lightning.LightningModule.setup
|
|
:noindex:
|
|
|
|
tbptt_split_batch
|
|
~~~~~~~~~~~~~~~~~
|
|
|
|
.. automethod:: pytorch_lightning.core.lightning.LightningModule.tbptt_split_batch
|
|
:noindex:
|
|
|
|
teardown
|
|
~~~~~~~~
|
|
|
|
.. automethod:: pytorch_lightning.core.lightning.LightningModule.teardown
|
|
:noindex:
|
|
|
|
train_dataloader
|
|
~~~~~~~~~~~~~~~~
|
|
|
|
.. automethod:: pytorch_lightning.core.lightning.LightningModule.train_dataloader
|
|
:noindex:
|
|
|
|
val_dataloader
|
|
~~~~~~~~~~~~~~
|
|
|
|
.. automethod:: pytorch_lightning.core.lightning.LightningModule.val_dataloader
|
|
:noindex:
|
|
|
|
test_dataloader
|
|
~~~~~~~~~~~~~~~
|
|
|
|
.. automethod:: pytorch_lightning.core.lightning.LightningModule.test_dataloader
|
|
:noindex:
|
|
|
|
predict_dataloader
|
|
~~~~~~~~~~~~~~~~~~
|
|
|
|
.. automethod:: pytorch_lightning.core.lightning.LightningModule.predict_dataloader
|
|
:noindex:
|
|
|
|
on_train_dataloader
|
|
~~~~~~~~~~~~~~~~~~~
|
|
|
|
.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_train_dataloader
|
|
:noindex:
|
|
|
|
on_val_dataloader
|
|
~~~~~~~~~~~~~~~~~
|
|
|
|
.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_val_dataloader
|
|
:noindex:
|
|
|
|
on_test_dataloader
|
|
~~~~~~~~~~~~~~~~~~
|
|
|
|
.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_test_dataloader
|
|
:noindex:
|
|
|
|
on_predict_dataloader
|
|
~~~~~~~~~~~~~~~~~~~~~
|
|
|
|
.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_predict_dataloader
|
|
:noindex:
|
|
|
|
transfer_batch_to_device
|
|
~~~~~~~~~~~~~~~~~~~~~~~~
|
|
|
|
.. automethod:: pytorch_lightning.core.lightning.LightningModule.transfer_batch_to_device
|
|
:noindex:
|
|
|
|
on_before_batch_transfer
|
|
~~~~~~~~~~~~~~~~~~~~~~~~
|
|
|
|
.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_before_batch_transfer
|
|
:noindex:
|
|
|
|
on_after_batch_transfer
|
|
~~~~~~~~~~~~~~~~~~~~~~~
|
|
|
|
.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_after_batch_transfer
|
|
:noindex:
|
|
|
|
add_to_queue
|
|
~~~~~~~~~~~~
|
|
|
|
.. automethod:: pytorch_lightning.core.lightning.LightningModule.add_to_queue
|
|
:noindex:
|
|
|
|
get_from_queue
|
|
~~~~~~~~~~~~~~
|
|
|
|
.. automethod:: pytorch_lightning.core.lightning.LightningModule.get_from_queue
|
|
:noindex:
|