2020-01-21 20:18:32 +00:00
|
|
|
.. role:: hidden
|
|
|
|
:class: hidden-section
|
2020-09-14 01:04:21 +00:00
|
|
|
|
|
|
|
.. _lightning_module:
|
2020-01-21 20:18:32 +00:00
|
|
|
|
|
|
|
LightningModule
|
2020-02-09 22:39:10 +00:00
|
|
|
===============
|
2020-08-11 23:39:43 +00:00
|
|
|
A :class:`~LightningModule` organizes your PyTorch code into 5 sections:
|
2020-02-09 22:39:10 +00:00
|
|
|
|
2020-08-11 23:39:43 +00:00
|
|
|
- Computations (init)
|
|
|
|
- Train loop (training_step)
|
|
|
|
- Validation loop (validation_step)
|
|
|
|
- Test loop (test_step)
|
|
|
|
- Optimizers (configure_optimizers)
|
2020-01-21 20:18:32 +00:00
|
|
|
|
2020-08-11 23:39:43 +00:00
|
|
|
|
|
|
|
|
|
2020-08-13 22:52:47 +00:00
|
|
|
.. raw:: html
|
|
|
|
|
2020-10-08 19:54:52 +00:00
|
|
|
<video width="50%" max-width="400px" controls autoplay muted playsinline src="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/pl_mod_vid.m4v"></video>
|
2020-08-11 23:39:43 +00:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Notice a few things.
|
|
|
|
|
|
|
|
1. It's the SAME code.
|
|
|
|
2. The PyTorch code IS NOT abstracted - just organized.
|
|
|
|
3. All the other code that's not in the :class:`~LightningModule`
|
|
|
|
has been automated for you by the trainer.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
net = Net()
|
|
|
|
trainer = Trainer()
|
|
|
|
trainer.fit(net)
|
|
|
|
|
|
|
|
4. There are no ``.cuda()`` or ``.to()`` calls... Lightning does these for you.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
# don't do in lightning
|
|
|
|
x = torch.Tensor(2, 3)
|
|
|
|
x = x.cuda()
|
|
|
|
x = x.to(device)
|
|
|
|
|
|
|
|
# do this instead
|
|
|
|
x = x # leave it alone!
|
|
|
|
|
|
|
|
# or to init a new tensor
|
|
|
|
new_x = torch.Tensor(2, 3)
|
2020-08-18 21:51:38 +00:00
|
|
|
new_x = new_x.type_as(x)
|
2020-08-11 23:39:43 +00:00
|
|
|
|
|
|
|
5. There are no samplers for distributed training; Lightning also does this for you.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
# Don't do in Lightning...
|
|
|
|
data = MNIST(...)
|
|
|
|
sampler = DistributedSampler(data)
|
|
|
|
DataLoader(data, sampler=sampler)
|
|
|
|
|
|
|
|
# do this instead
|
|
|
|
data = MNIST(...)
|
|
|
|
DataLoader(data)
|
|
|
|
|
|
|
|
6. A :class:`~LightningModule` is a :class:`torch.nn.Module` but with added functionality. Use it as such!
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
net = Net.load_from_checkpoint(PATH)
|
|
|
|
net.freeze()
|
|
|
|
out = net(x)
|
|
|
|
|
|
|
|
Thus, to use Lightning, you just need to organize your code, which takes about 30 minutes
|
|
|
|
(and let's be real, you probably should do that anyway).
|
|
|
|
|
|
|
|
------------
|
|
|
|
|
|
|
|
Minimal Example
|
|
|
|
---------------
|
|
|
|
|
|
|
|
Here are the only required methods.
|
|
|
|
|
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
>>> import pytorch_lightning as pl
|
|
|
|
>>> class LitModel(pl.LightningModule):
|
|
|
|
...
|
|
|
|
... def __init__(self):
|
|
|
|
... super().__init__()
|
|
|
|
... self.l1 = torch.nn.Linear(28 * 28, 10)
|
|
|
|
...
|
|
|
|
... def forward(self, x):
|
|
|
|
... return torch.relu(self.l1(x.view(x.size(0), -1)))
|
|
|
|
...
|
|
|
|
... def training_step(self, batch, batch_idx):
|
|
|
|
... x, y = batch
|
|
|
|
... y_hat = self(x)
|
|
|
|
... loss = F.cross_entropy(y_hat, y)
|
2020-09-30 12:31:16 +00:00
|
|
|
... return loss
|
2020-08-11 23:39:43 +00:00
|
|
|
...
|
|
|
|
... def configure_optimizers(self):
|
|
|
|
... return torch.optim.Adam(self.parameters(), lr=0.02)
|
|
|
|
|
|
|
|
Which you can train by doing:
|
|
|
|
|
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
train_loader = DataLoader(MNIST(os.getcwd(), download=True, transform=transforms.ToTensor()))
|
|
|
|
trainer = pl.Trainer()
|
|
|
|
model = LitModel()
|
|
|
|
|
|
|
|
trainer.fit(model, train_loader)
|
|
|
|
|
|
|
|
----------
|
|
|
|
|
|
|
|
LightningModule for research
|
|
|
|
----------------------------
|
|
|
|
For research, LightningModules are best structured as systems.
|
|
|
|
|
|
|
|
A model (colloquially) refers to something like a ResNet or RNN. A system may be a collection of models. Here
|
|
|
|
are examples of systems:
|
|
|
|
|
|
|
|
- GAN (generator, discriminator)
|
|
|
|
- RL (policy, actor, critic)
|
|
|
|
- Autoencoders (encoder, decoder)
|
|
|
|
- Seq2Seq (encoder, attention, decoder)
|
|
|
|
- etc...
|
|
|
|
|
|
|
|
A LightningModule is best used to define a complex system:
|
|
|
|
|
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
import pytorch_lightning as pl
|
|
|
|
import torch
|
|
|
|
from torch import nn
|
|
|
|
|
|
|
|
class Autoencoder(pl.LightningModule):
|
|
|
|
|
|
|
|
def __init__(self, latent_dim=2):
|
|
|
|
super().__init__()
|
|
|
|
self.encoder = nn.Sequential(nn.Linear(28 * 28, 256), nn.ReLU(), nn.Linear(256, latent_dim))
|
|
|
|
self.decoder = nn.Sequential(nn.Linear(latent_dim, 256), nn.ReLU(), nn.Linear(256, 28 * 28))
|
|
|
|
|
|
|
|
def training_step(self, batch, batch_idx):
|
|
|
|
x, _ = batch
|
|
|
|
|
|
|
|
# encode
|
|
|
|
x = x.view(x.size(0), -1)
|
|
|
|
z = self.encoder(x)
|
|
|
|
|
|
|
|
# decode
|
|
|
|
recons = self.decoder(z)
|
|
|
|
|
|
|
|
# reconstruction
|
|
|
|
reconstruction_loss = nn.functional.mse_loss(recons, x)
|
2020-09-30 12:31:16 +00:00
|
|
|
return reconstruction_loss
|
2020-08-11 23:39:43 +00:00
|
|
|
|
|
|
|
def validation_step(self, batch, batch_idx):
|
|
|
|
x, _ = batch
|
|
|
|
x = x.view(x.size(0), -1)
|
|
|
|
z = self.encoder(x)
|
|
|
|
recons = self.decoder(z)
|
|
|
|
reconstruction_loss = nn.functional.mse_loss(recons, x)
|
2020-09-30 12:31:16 +00:00
|
|
|
self.log('val_reconstruction', reconstruction_loss)
|
2020-08-11 23:39:43 +00:00
|
|
|
|
|
|
|
def configure_optimizers(self):
|
|
|
|
return torch.optim.Adam(self.parameters(), lr=0.0002)
|
|
|
|
|
|
|
|
Which can be trained like this:
|
|
|
|
|
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
autoencoder = Autoencoder()
|
|
|
|
trainer = pl.Trainer(gpus=1)
|
|
|
|
trainer.fit(autoencoder, train_dataloader, val_dataloader)
|
|
|
|
|
|
|
|
This simple model generates examples that look like this (the encoders and decoders are too weak):
|
|
|
|
|
|
|
|
.. figure:: https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/ae_docs.png
|
|
|
|
:width: 300
|
|
|
|
|
|
|
|
The methods above are part of the Lightning interface:
|
|
|
|
|
|
|
|
- training_step
|
|
|
|
- validation_step
|
|
|
|
- test_step
|
|
|
|
- configure_optimizers
|
|
|
|
|
|
|
|
Note that in this case, the train loop and val loop are exactly the same. We can of course reuse this code.
|
|
|
|
|
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
class Autoencoder(pl.LightningModule):
|
|
|
|
|
|
|
|
def __init__(self, latent_dim=2):
|
|
|
|
super().__init__()
|
|
|
|
self.encoder = nn.Sequential(nn.Linear(28 * 28, 256), nn.ReLU(), nn.Linear(256, latent_dim))
|
|
|
|
self.decoder = nn.Sequential(nn.Linear(latent_dim, 256), nn.ReLU(), nn.Linear(256, 28 * 28))
|
|
|
|
|
|
|
|
def training_step(self, batch, batch_idx):
|
|
|
|
loss = self.shared_step(batch)
|
2020-10-07 21:55:24 +00:00
|
|
|
|
2020-09-30 12:31:16 +00:00
|
|
|
return loss
|
2020-08-11 23:39:43 +00:00
|
|
|
|
|
|
|
def validation_step(self, batch, batch_idx):
|
|
|
|
loss = self.shared_step(batch)
|
2020-09-30 12:31:16 +00:00
|
|
|
self.log('val_loss', loss)
|
2020-08-11 23:39:43 +00:00
|
|
|
|
|
|
|
def shared_step(self, batch):
|
|
|
|
x, _ = batch
|
|
|
|
|
|
|
|
# encode
|
|
|
|
x = x.view(x.size(0), -1)
|
|
|
|
z = self.encoder(x)
|
|
|
|
|
|
|
|
# decode
|
|
|
|
recons = self.decoder(z)
|
|
|
|
|
|
|
|
# loss
|
|
|
|
return nn.functional.mse_loss(recons, x)
|
|
|
|
|
|
|
|
def configure_optimizers(self):
|
|
|
|
return torch.optim.Adam(self.parameters(), lr=0.0002)
|
|
|
|
|
|
|
|
We create a new method called `shared_step` that all loops can use. This method name is arbitrary and NOT reserved.
|
|
|
|
|
|
|
|
Inference in Research
|
|
|
|
^^^^^^^^^^^^^^^^^^^^^
|
|
|
|
In the case where we want to perform inference with the system we can add a `forward` method to the LightningModule.
|
|
|
|
|
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
class Autoencoder(pl.LightningModule):
|
|
|
|
def forward(self, x):
|
|
|
|
return self.decoder(x)
|
|
|
|
|
|
|
|
The advantage of adding a forward is that in complex systems, you can do a much more involved inference procedure,
|
|
|
|
such as text generation:
|
|
|
|
|
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
class Seq2Seq(pl.LightningModule):
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
embeddings = self.embedding(x)
|
|
|
|
hidden_states = self.encoder(embeddings)
|
|
|
|
for h in hidden_states:
|
|
|
|
# decode
|
|
|
|
...
|
|
|
|
return decoded
|
|
|
|
|
|
|
|
---------------------
|
|
|
|
|
|
|
|
LightningModule for production
|
|
|
|
------------------------------
|
|
|
|
For cases like production, you might want to iterate over different models inside a LightningModule.
|
|
|
|
|
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
import pytorch_lightning as pl
|
|
|
|
from pytorch_lightning.metrics import functional as FM
|
|
|
|
|
|
|
|
class ClassificationTask(pl.LightningModule):
|
|
|
|
|
|
|
|
def __init__(self, model):
|
|
|
|
super().__init__()
|
|
|
|
self.model = model
|
|
|
|
|
|
|
|
def training_step(self, batch, batch_idx):
|
|
|
|
x, y = batch
|
|
|
|
y_hat = self.model(x)
|
|
|
|
loss = F.cross_entropy(y_hat, y)
|
2020-09-30 12:31:16 +00:00
|
|
|
return loss
|
2020-08-11 23:39:43 +00:00
|
|
|
|
|
|
|
def validation_step(self, batch, batch_idx):
|
|
|
|
x, y = batch
|
|
|
|
y_hat = self.model(x)
|
|
|
|
loss = F.cross_entropy(y_hat, y)
|
|
|
|
acc = FM.accuracy(y_hat, y)
|
2020-09-21 02:58:43 +00:00
|
|
|
|
|
|
|
# log both metrics; a checkpoint callback can monitor 'val_loss'
|
2020-09-30 12:31:16 +00:00
|
|
|
metrics = {'val_acc': acc, 'val_loss': loss}
|
|
|
|
self.log_dict(metrics)
|
|
|
|
return metrics
|
2020-08-11 23:39:43 +00:00
|
|
|
|
|
|
|
def test_step(self, batch, batch_idx):
|
2020-09-30 12:31:16 +00:00
|
|
|
metrics = self.validation_step(batch, batch_idx)
|
|
|
|
metrics = {'test_acc': metrics['val_acc'], 'test_loss': metrics['val_loss']}
|
|
|
|
self.log_dict(metrics)
|
2020-08-11 23:39:43 +00:00
|
|
|
|
|
|
|
def configure_optimizers(self):
|
|
|
|
return torch.optim.Adam(self.model.parameters(), lr=0.02)
|
|
|
|
|
|
|
|
Then pass in any arbitrary model to be fit with this task:
|
|
|
|
|
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
for model in [resnet50(), vgg16(), BidirectionalRNN()]:
|
|
|
|
task = ClassificationTask(model)
|
|
|
|
|
|
|
|
trainer = Trainer(gpus=2)
|
|
|
|
trainer.fit(task, train_dataloader, val_dataloader)
|
|
|
|
|
|
|
|
Tasks can be arbitrarily complex such as implementing GAN training, self-supervised or even RL.
|
|
|
|
|
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
class GANTask(pl.LightningModule):
|
|
|
|
|
|
|
|
def __init__(self, generator, discriminator):
|
|
|
|
super().__init__()
|
|
|
|
self.generator = generator
|
|
|
|
self.discriminator = discriminator
|
|
|
|
...
|
|
|
|
|
|
|
|
Inference in production
|
|
|
|
^^^^^^^^^^^^^^^^^^^^^^^
|
|
|
|
When used like this, the model can be separated from the Task and thus used in production without needing to keep it in
|
|
|
|
a `LightningModule`.
|
|
|
|
|
|
|
|
- You can export to ONNX.
|
|
|
|
- Or trace using JIT.
|
|
|
|
- Or run in the Python runtime.
|
|
|
|
|
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
task = ClassificationTask(model)
|
|
|
|
|
|
|
|
trainer = Trainer(gpus=2)
|
|
|
|
trainer.fit(task, train_dataloader, val_dataloader)
|
|
|
|
|
|
|
|
# use model after training or load weights and drop into the production system
|
|
|
|
model.eval()
|
|
|
|
y_hat = model(x)
|
|
|
|
|
|
|
|
|
|
|
|
Training loop
|
|
|
|
-------------
|
|
|
|
To add a training loop, use the `training_step` method:
|
|
|
|
|
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
class LitClassifier(pl.LightningModule):
|
|
|
|
|
|
|
|
def __init__(self, model):
|
|
|
|
super().__init__()
|
|
|
|
self.model = model
|
|
|
|
|
|
|
|
def training_step(self, batch, batch_idx):
|
|
|
|
x, y = batch
|
|
|
|
y_hat = self.model(x)
|
|
|
|
loss = F.cross_entropy(y_hat, y)
|
2020-09-30 12:31:16 +00:00
|
|
|
return loss
|
2020-08-11 23:39:43 +00:00
|
|
|
|
|
|
|
Under the hood, Lightning does the following (pseudocode):
|
|
|
|
|
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
# put model in train mode
|
|
|
|
model.train()
|
|
|
|
torch.set_grad_enabled(True)
|
|
|
|
|
|
|
|
outs = []
|
|
|
|
for batch in train_dataloader:
|
|
|
|
# forward
|
|
|
|
out = training_step(batch)
|
|
|
|
|
|
|
|
# backward
|
|
|
|
loss.backward()
|
|
|
|
|
|
|
|
# apply and clear grads
|
|
|
|
optimizer.step()
|
|
|
|
optimizer.zero_grad()
|
|
|
|
|
|
|
|
Training epoch-level metrics
|
|
|
|
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
2020-09-30 12:31:16 +00:00
|
|
|
If you want to calculate epoch-level metrics and log them, use the `.log` method:
|
2020-08-11 23:39:43 +00:00
|
|
|
|
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
def training_step(self, batch, batch_idx):
|
|
|
|
x, y = batch
|
|
|
|
y_hat = self.model(x)
|
|
|
|
loss = F.cross_entropy(y_hat, y)
|
|
|
|
|
|
|
|
# logs metrics for each training_step, and the average across the epoch, to the progress bar and logger
|
2020-09-30 12:41:24 +00:00
|
|
|
self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
|
2020-09-30 12:31:16 +00:00
|
|
|
return loss
|
2020-08-11 23:39:43 +00:00
|
|
|
|
2020-09-30 12:31:16 +00:00
|
|
|
The `.log` method automatically reduces the requested metrics across the full epoch.
|
2020-08-11 23:39:43 +00:00
|
|
|
Here's the pseudocode of what it does under the hood:
|
|
|
|
|
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
outs = []
|
|
|
|
for batch in train_dataloader:
|
|
|
|
# forward
|
|
|
|
out = training_step(batch)
outs.append(out)
|
|
|
|
|
|
|
|
# backward
|
|
|
|
loss.backward()
|
|
|
|
|
|
|
|
# apply and clear grads
|
|
|
|
optimizer.step()
|
|
|
|
optimizer.zero_grad()
|
|
|
|
|
|
|
|
epoch_metric = torch.mean(torch.stack([x['train_loss'] for x in outs]))
|
|
|
|
|
|
|
|
Train epoch-level operations
|
|
|
|
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
|
|
|
If you need to do something with all the outputs of each `training_step`, override `training_epoch_end` yourself.
|
|
|
|
|
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
def training_step(self, batch, batch_idx):
|
|
|
|
x, y = batch
|
|
|
|
y_hat = self.model(x)
|
|
|
|
loss = F.cross_entropy(y_hat, y)
|
2020-09-30 12:31:16 +00:00
|
|
|
preds = ...
|
|
|
|
return {'loss': loss, 'other_stuff': preds}
|
2020-08-11 23:39:43 +00:00
|
|
|
|
|
|
|
def training_epoch_end(self, training_step_outputs):
|
2020-09-30 12:31:16 +00:00
|
|
|
for pred in training_step_outputs:
|
|
|
|
# do something
|
2020-08-11 23:39:43 +00:00
|
|
|
|
|
|
|
The matching pseudocode is:
|
|
|
|
|
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
outs = []
|
|
|
|
for batch in train_dataloader:
|
|
|
|
# forward
|
|
|
|
out = training_step(batch)
outs.append(out)
|
|
|
|
|
|
|
|
# backward
|
|
|
|
loss.backward()
|
|
|
|
|
|
|
|
# apply and clear grads
|
|
|
|
optimizer.step()
|
|
|
|
optimizer.zero_grad()
|
|
|
|
|
2020-10-05 03:34:27 +00:00
|
|
|
training_epoch_end(outs)
|
2020-08-11 23:39:43 +00:00
|
|
|
|
|
|
|
Training with DataParallel
|
|
|
|
^^^^^^^^^^^^^^^^^^^^^^^^^^
|
|
|
|
When training using a `distributed_backend` that splits data from each batch across GPUs, sometimes you might
|
|
|
|
need to aggregate them on the master GPU for processing (dp or ddp2).
|
|
|
|
|
|
|
|
In this case, implement the `training_step_end` method:
|
|
|
|
|
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
def training_step(self, batch, batch_idx):
|
|
|
|
x, y = batch
|
|
|
|
y_hat = self.model(x)
|
|
|
|
loss = F.cross_entropy(y_hat, y)
|
2020-09-30 12:31:16 +00:00
|
|
|
pred = ...
|
|
|
|
return {'loss': loss, 'pred': pred}
|
2020-08-11 23:39:43 +00:00
|
|
|
|
|
|
|
def training_step_end(self, batch_parts):
|
2020-09-30 12:31:16 +00:00
|
|
|
gpu_0_prediction = batch_parts[0]['pred']
|
|
|
|
gpu_1_prediction = batch_parts[1]['pred']
|
2020-08-11 23:39:43 +00:00
|
|
|
|
|
|
|
# do something with both outputs
|
2020-09-30 12:31:16 +00:00
|
|
|
return (batch_parts[0]['loss'] + batch_parts[1]['loss']) / 2
|
2020-08-11 23:39:43 +00:00
|
|
|
|
|
|
|
def training_epoch_end(self, training_step_outputs):
|
2020-09-30 12:31:16 +00:00
|
|
|
for out in training_step_outputs:
|
|
|
|
# do something with preds
|
2020-08-11 23:39:43 +00:00
|
|
|
|
|
|
|
The full pseudocode that Lightning does under the hood is:
|
|
|
|
|
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
outs = []
|
|
|
|
for train_batch in train_dataloader:
|
|
|
|
batches = split_batch(train_batch)
|
|
|
|
dp_outs = []
|
|
|
|
for sub_batch in batches:
|
|
|
|
# 1
|
|
|
|
dp_out = training_step(sub_batch)
|
|
|
|
dp_outs.append(dp_out)
|
|
|
|
|
|
|
|
# 2
|
|
|
|
out = training_step_end(dp_outs)
|
|
|
|
outs.append(out)
|
|
|
|
|
|
|
|
# do something with the outputs for all batches
|
|
|
|
# 3
|
|
|
|
training_epoch_end(outs)
|
|
|
|
|
|
|
|
------------------
|
|
|
|
|
|
|
|
Validation loop
|
|
|
|
---------------
|
|
|
|
To add a validation loop, override the `validation_step` method of the :class:`~LightningModule`:
|
|
|
|
|
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
class LitModel(pl.LightningModule):
|
|
|
|
def validation_step(self, batch, batch_idx):
|
|
|
|
x, y = batch
|
|
|
|
y_hat = self.model(x)
|
|
|
|
loss = F.cross_entropy(y_hat, y)
|
2020-09-30 12:31:16 +00:00
|
|
|
self.log('val_loss', loss)
|
2020-08-11 23:39:43 +00:00
|
|
|
|
|
|
|
Under the hood, Lightning does the following:
|
|
|
|
|
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
# ...
|
|
|
|
for batch in train_dataloader:
|
|
|
|
loss = model.training_step()
|
|
|
|
loss.backward()
|
|
|
|
# ...
|
|
|
|
|
|
|
|
if validate_at_some_point:
|
|
|
|
# disable grads + batchnorm + dropout
|
|
|
|
torch.set_grad_enabled(False)
|
|
|
|
model.eval()
|
|
|
|
|
|
|
|
# ----------------- VAL LOOP ---------------
|
|
|
|
for val_batch in model.val_dataloader:
|
|
|
|
val_out = model.validation_step(val_batch)
|
|
|
|
# ----------------- VAL LOOP ---------------
|
|
|
|
|
|
|
|
# enable grads + batchnorm + dropout
|
|
|
|
torch.set_grad_enabled(True)
|
|
|
|
model.train()
|
|
|
|
|
|
|
|
Validation epoch-level metrics
|
|
|
|
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
|
|
|
If you need to do something with all the outputs of each `validation_step`, override `validation_epoch_end`.
|
|
|
|
|
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
def validation_step(self, batch, batch_idx):
|
|
|
|
x, y = batch
|
|
|
|
y_hat = self.model(x)
|
|
|
|
loss = F.cross_entropy(y_hat, y)
|
2020-09-30 12:31:16 +00:00
|
|
|
pred = ...
|
|
|
|
return pred
|
2020-08-11 23:39:43 +00:00
|
|
|
|
|
|
|
def validation_epoch_end(self, validation_step_outputs):
|
2020-09-30 12:31:16 +00:00
|
|
|
for pred in validation_step_outputs:
|
|
|
|
# do something with a pred
|
2020-08-11 23:39:43 +00:00
|
|
|
|
|
|
|
Validating with DataParallel
|
|
|
|
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
|
|
|
When training using a `distributed_backend` that splits data from each batch across GPUs, sometimes you might
|
|
|
|
need to aggregate them on the master GPU for processing (dp or ddp2).
|
|
|
|
|
|
|
|
In this case, implement the `validation_step_end` method:
|
|
|
|
|
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
def validation_step(self, batch, batch_idx):
|
|
|
|
x, y = batch
|
|
|
|
y_hat = self.model(x)
|
|
|
|
loss = F.cross_entropy(y_hat, y)
|
2020-09-30 12:31:16 +00:00
|
|
|
pred = ...
|
|
|
|
return {'loss': loss, 'pred': pred}
|
2020-08-11 23:39:43 +00:00
|
|
|
|
|
|
|
def validation_step_end(self, batch_parts):
|
2020-09-30 12:31:16 +00:00
|
|
|
gpu_0_prediction = batch_parts[0]['pred']
|
|
|
|
gpu_1_prediction = batch_parts[1]['pred']
|
2020-08-11 23:39:43 +00:00
|
|
|
|
|
|
|
# do something with both outputs
|
2020-09-30 12:31:16 +00:00
|
|
|
return (batch_parts[0]['loss'] + batch_parts[1]['loss']) / 2
|
2020-08-11 23:39:43 +00:00
|
|
|
|
|
|
|
def validation_epoch_end(self, validation_step_outputs):
|
2020-09-30 12:31:16 +00:00
|
|
|
for out in validation_step_outputs:
|
|
|
|
# do something with preds
|
2020-08-11 23:39:43 +00:00
|
|
|
|
|
|
|
The full pseudocode that Lightning does under the hood is:
|
|
|
|
|
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
outs = []
|
|
|
|
for batch in dataloader:
|
|
|
|
batches = split_batch(batch)
|
|
|
|
dp_outs = []
|
|
|
|
for sub_batch in batches:
|
|
|
|
# 1
|
|
|
|
dp_out = validation_step(sub_batch)
|
|
|
|
dp_outs.append(dp_out)
|
|
|
|
|
|
|
|
# 2
|
|
|
|
out = validation_step_end(dp_outs)
|
|
|
|
outs.append(out)
|
|
|
|
|
|
|
|
# do something with the outputs for all batches
|
|
|
|
# 3
|
|
|
|
validation_epoch_end(outs)
|
|
|
|
|
|
|
|
----------------
|
|
|
|
|
|
|
|
Test loop
|
|
|
|
---------
|
|
|
|
The process for adding a test loop is the same as the process for adding a validation loop. Please refer to
|
|
|
|
the section above for details.
|
|
|
|
|
|
|
|
The only difference is that the test loop is only called when `.test()` is used:
|
|
|
|
|
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
model = Model()
|
|
|
|
trainer = Trainer()
|
|
|
|
trainer.fit(model)
|
|
|
|
|
|
|
|
# automatically loads the best weights for you
|
|
|
|
trainer.test(model)
|
|
|
|
|
|
|
|
There are two ways to call `test()`:
|
|
|
|
|
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
# call after training
|
|
|
|
trainer = Trainer()
|
|
|
|
trainer.fit(model)
|
|
|
|
|
|
|
|
# automatically loads the best weights
|
|
|
|
trainer.test(test_dataloaders=test_dataloader)
|
|
|
|
|
|
|
|
# or call with pretrained model
|
|
|
|
model = MyLightningModule.load_from_checkpoint(PATH)
|
|
|
|
trainer = Trainer()
|
|
|
|
trainer.test(model, test_dataloaders=test_dataloader)
|
|
|
|
|
|
|
|
----------
|
|
|
|
|
|
|
|
Live demo
|
|
|
|
---------
|
|
|
|
Check out this
|
|
|
|
`COLAB <https://colab.research.google.com/drive/1F_RNcHzTfFuQf-LeKvSlud6x7jXYkG31#scrollTo=HOk9c4_35FKg>`_
|
|
|
|
for a live demo.
|
|
|
|
|
|
|
|
-----------
|
|
|
|
|
|
|
|
LightningModule API
|
|
|
|
-------------------
|
|
|
|
|
|
|
|
Training loop methods
|
|
|
|
^^^^^^^^^^^^^^^^^^^^^
|
|
|
|
|
|
|
|
training_step
|
|
|
|
~~~~~~~~~~~~~
|
|
|
|
|
2020-10-08 00:41:56 +00:00
|
|
|
.. autofunction:: pytorch_lightning.core.lightning.LightningModule.training_step
|
2020-08-11 23:39:43 +00:00
|
|
|
:noindex:
|
|
|
|
|
|
|
|
training_step_end
|
|
|
|
~~~~~~~~~~~~~~~~~
|
|
|
|
|
2020-10-08 00:41:56 +00:00
|
|
|
.. autofunction:: pytorch_lightning.core.lightning.LightningModule.training_step_end
|
2020-08-11 23:39:43 +00:00
|
|
|
:noindex:
|
|
|
|
|
|
|
|
training_epoch_end
|
|
|
|
~~~~~~~~~~~~~~~~~~
|
2020-10-08 00:41:56 +00:00
|
|
|
.. autofunction:: pytorch_lightning.core.lightning.LightningModule.training_epoch_end
|
2020-08-11 23:39:43 +00:00
|
|
|
:noindex:
|
|
|
|
|
|
|
|
---------------
|
|
|
|
|
|
|
|
Validation loop methods
|
|
|
|
^^^^^^^^^^^^^^^^^^^^^^^
|
|
|
|
|
|
|
|
validation_step
|
|
|
|
~~~~~~~~~~~~~~~
|
|
|
|
|
2020-10-08 00:41:56 +00:00
|
|
|
.. autofunction:: pytorch_lightning.core.lightning.LightningModule.validation_step
|
2020-08-11 23:39:43 +00:00
|
|
|
:noindex:
|
|
|
|
|
|
|
|
validation_step_end
|
|
|
|
~~~~~~~~~~~~~~~~~~~
|
|
|
|
|
2020-10-08 00:41:56 +00:00
|
|
|
.. autofunction:: pytorch_lightning.core.lightning.LightningModule.validation_step_end
|
2020-08-11 23:39:43 +00:00
|
|
|
:noindex:
|
|
|
|
|
|
|
|
validation_epoch_end
|
|
|
|
~~~~~~~~~~~~~~~~~~~~
|
|
|
|
|
2020-10-08 00:41:56 +00:00
|
|
|
.. autofunction:: pytorch_lightning.core.lightning.LightningModule.validation_epoch_end
|
2020-08-11 23:39:43 +00:00
|
|
|
:noindex:
|
|
|
|
|
|
|
|
----------------
|
|
|
|
|
|
|
|
Test loop methods
|
|
|
|
^^^^^^^^^^^^^^^^^
|
|
|
|
|
|
|
|
test_step
|
|
|
|
~~~~~~~~~
|
|
|
|
|
2020-10-08 00:41:56 +00:00
|
|
|
.. autofunction:: pytorch_lightning.core.lightning.LightningModule.test_step
|
2020-08-11 23:39:43 +00:00
|
|
|
:noindex:
|
|
|
|
|
|
|
|
test_step_end
|
|
|
|
~~~~~~~~~~~~~
|
|
|
|
|
2020-10-08 00:41:56 +00:00
|
|
|
.. autofunction:: pytorch_lightning.core.lightning.LightningModule.test_step_end
|
2020-08-11 23:39:43 +00:00
|
|
|
:noindex:
|
|
|
|
|
|
|
|
test_epoch_end
|
|
|
|
~~~~~~~~~~~~~~
|
|
|
|
|
2020-10-08 00:41:56 +00:00
|
|
|
.. autofunction:: pytorch_lightning.core.lightning.LightningModule.test_epoch_end
|
2020-08-11 23:39:43 +00:00
|
|
|
:noindex:
|
|
|
|
|
|
|
|
--------------
|
|
|
|
|
|
|
|
configure_optimizers
|
|
|
|
^^^^^^^^^^^^^^^^^^^^
|
|
|
|
|
2020-10-08 00:41:56 +00:00
|
|
|
.. autofunction:: pytorch_lightning.core.lightning.LightningModule.configure_optimizers
|
2020-08-11 23:39:43 +00:00
|
|
|
:noindex:
|
|
|
|
|
|
|
|
--------------
|
|
|
|
|
|
|
|
Convenience methods
|
|
|
|
^^^^^^^^^^^^^^^^^^^
|
|
|
|
Use these methods for convenience.
|
|
|
|
|
|
|
|
print
|
|
|
|
~~~~~
|
|
|
|
|
2020-10-08 00:41:56 +00:00
|
|
|
.. autofunction:: pytorch_lightning.core.lightning.LightningModule.print
|
2020-08-11 23:39:43 +00:00
|
|
|
:noindex:
|
|
|
|
|
|
|
|
save_hyperparameters
|
|
|
|
~~~~~~~~~~~~~~~~~~~~
|
|
|
|
|
2020-10-08 00:41:56 +00:00
|
|
|
.. autofunction:: pytorch_lightning.core.lightning.LightningModule.save_hyperparameters
|
2020-08-11 23:39:43 +00:00
|
|
|
:noindex:
|
|
|
|
|
|
|
|
------------
|
|
|
|
|
2020-09-30 12:31:16 +00:00
|
|
|
Logging methods
|
|
|
|
^^^^^^^^^^^^^^^
|
|
|
|
Use these methods to interact with the loggers.
|
|
|
|
|
|
|
|
log
|
|
|
|
~~~
|
|
|
|
|
|
|
|
.. autofunction:: pytorch_lightning.core.lightning.LightningModule.log
|
|
|
|
:noindex:
|
|
|
|
|
|
|
|
log_dict
|
|
|
|
~~~~~~~~
|
|
|
|
|
|
|
|
.. autofunction:: pytorch_lightning.core.lightning.LightningModule.log_dict
|
|
|
|
:noindex:
|
|
|
|
|
|
|
|
------------
|
|
|
|
|
2020-08-11 23:39:43 +00:00
|
|
|
Inference methods
|
|
|
|
^^^^^^^^^^^^^^^^^
|
|
|
|
Use these hooks for inference with a LightningModule.
|
|
|
|
|
|
|
|
forward
|
|
|
|
~~~~~~~
|
|
|
|
|
2020-10-08 00:41:56 +00:00
|
|
|
.. autofunction:: pytorch_lightning.core.lightning.LightningModule.forward
|
2020-08-11 23:39:43 +00:00
|
|
|
:noindex:
|
|
|
|
|
|
|
|
freeze
|
|
|
|
~~~~~~
|
|
|
|
|
2020-10-08 00:41:56 +00:00
|
|
|
.. autofunction:: pytorch_lightning.core.lightning.LightningModule.freeze
|
2020-08-11 23:39:43 +00:00
|
|
|
:noindex:
|
|
|
|
|
|
|
|
to_onnx
|
|
|
|
~~~~~~~
|
|
|
|
|
2020-10-08 00:41:56 +00:00
|
|
|
.. autofunction:: pytorch_lightning.core.lightning.LightningModule.to_onnx
|
2020-08-11 23:39:43 +00:00
|
|
|
:noindex:
|
|
|
|
|
2020-09-03 18:24:44 +00:00
|
|
|
to_torchscript
|
|
|
|
~~~~~~~~~~~~~~
|
|
|
|
|
2020-10-08 00:41:56 +00:00
|
|
|
.. autofunction:: pytorch_lightning.core.lightning.LightningModule.to_torchscript
|
2020-09-03 18:24:44 +00:00
|
|
|
:noindex:
|
|
|
|
|
2020-08-11 23:39:43 +00:00
|
|
|
unfreeze
|
|
|
|
~~~~~~~~
|
|
|
|
|
2020-10-08 00:41:56 +00:00
|
|
|
.. autofunction:: pytorch_lightning.core.lightning.LightningModule.unfreeze
|
2020-08-11 23:39:43 +00:00
|
|
|
:noindex:
|
|
|
|
|
|
|
|
------------
|
|
|
|
|
|
|
|
Properties
|
|
|
|
^^^^^^^^^^
|
|
|
|
These are properties available in a LightningModule.
|
|
|
|
|
|
|
|
-----------
|
|
|
|
|
|
|
|
current_epoch
|
|
|
|
~~~~~~~~~~~~~
|
|
|
|
The current epoch
|
|
|
|
|
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
def training_step(...):
|
|
|
|
if self.current_epoch == 0:
|
|
|
|
|
|
|
|
-------------
|
|
|
|
|
|
|
|
device
|
|
|
|
~~~~~~
|
|
|
|
The device the module is on. Use it to keep your code device-agnostic.
|
|
|
|
|
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
def training_step(...):
|
|
|
|
z = torch.rand(2, 3, device=self.device)
|
|
|
|
|
|
|
|
-------------
|
|
|
|
|
|
|
|
global_rank
|
|
|
|
~~~~~~~~~~~
|
|
|
|
The global_rank of this LightningModule. Lightning saves logs, weights, etc. only from global_rank = 0. You
|
|
|
|
normally do not need to use this property.
|
|
|
|
|
|
|
|
Global rank refers to the index of that GPU across ALL GPUs. For example, if using 10 machines, each with 4 GPUs,
|
|
|
|
the 4th GPU on the 10th machine has global_rank = 39.
|
|
|
|
|
|
|
|
-------------
|
|
|
|
|
|
|
|
global_step
|
|
|
|
~~~~~~~~~~~
|
|
|
|
The current step (does not reset each epoch)
|
|
|
|
|
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
def training_step(...):
|
|
|
|
self.logger.experiment.log_image(..., step=self.global_step)
|
|
|
|
|
|
|
|
-------------
|
|
|
|
|
|
|
|
hparams
|
|
|
|
~~~~~~~
|
|
|
|
After calling `save_hyperparameters`, anything passed to ``__init__()`` is available via ``self.hparams``.
|
|
|
|
|
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
def __init__(self, learning_rate):
|
|
|
|
self.save_hyperparameters()
|
|
|
|
|
|
|
|
def configure_optimizers(self):
|
|
|
|
return Adam(self.parameters(), lr=self.hparams.learning_rate)
|
|
|
|
|
|
|
|
--------------
|
|
|
|
|
|
|
|
logger
|
|
|
|
~~~~~~
|
|
|
|
The current logger being used (tensorboard or other supported logger)
|
|
|
|
|
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
def training_step(...):
|
|
|
|
# the generic logger (same no matter if tensorboard or other supported logger)
|
|
|
|
self.logger
|
|
|
|
|
|
|
|
# the particular logger
|
|
|
|
tensorboard_logger = self.logger.experiment
|
|
|
|
|
|
|
|
--------------
|
|
|
|
|
|
|
|
local_rank
|
|
|
|
~~~~~~~~~~~
|
|
|
|
The local_rank of this LightningModule. Lightning saves logs, weights, etc. only from global_rank = 0. You
|
|
|
|
normally do not need to use this property.
|
|
|
|
|
|
|
|
Local rank refers to the rank on that machine. For example, if using 10 machines, the GPU at index 0 on each machine
|
|
|
|
has local_rank = 0.
|
|
|
|
|
|
|
|
|
|
|
|
-----------
|
|
|
|
|
|
|
|
precision
|
|
|
|
~~~~~~~~~
|
|
|
|
The type of precision used:
|
|
|
|
|
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
def training_step(...):
|
|
|
|
if self.precision == 16:
|
|
|
|
|
|
|
|
------------
|
|
|
|
|
|
|
|
trainer
|
|
|
|
~~~~~~~
|
|
|
|
A reference to the Trainer.
|
|
|
|
|
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
def training_step(...):
|
|
|
|
max_steps = self.trainer.max_steps
|
|
|
|
any_flag = self.trainer.any_flag
|
|
|
|
|
|
|
|
------------
|
|
|
|
|
2020-09-14 08:35:14 +00:00
|
|
|
use_amp
|
|
|
|
~~~~~~~
|
|
|
|
True if using Automatic Mixed Precision (AMP)
|
|
|
|
|
|
|
|
------------
|
|
|
|
|
2020-08-11 23:39:43 +00:00
|
|
|
use_ddp
|
|
|
|
~~~~~~~
|
|
|
|
True if using ddp
|
|
|
|
|
|
|
|
------------
|
|
|
|
|
|
|
|
use_ddp2
|
|
|
|
~~~~~~~~
|
|
|
|
True if using ddp2
|
|
|
|
|
|
|
|
------------
|
|
|
|
|
|
|
|
use_dp
|
|
|
|
~~~~~~
|
|
|
|
True if using dp
|
|
|
|
|
|
|
|
------------
|
|
|
|
|
|
|
|
use_tpu
|
|
|
|
~~~~~~~
|
|
|
|
True if using TPUs
|
|
|
|
|
|
|
|
--------------
|
|
|
|
|
|
|
|
Hooks
|
|
|
|
-----
|
|
|
|
|
|
|
|
Hook lifecycle pseudocode
|
|
|
|
^^^^^^^^^^^^^^^^^^^^^^^^^
|
|
|
|
This is the pseudocode to describe how all the hooks are called during a call to `.fit()`:
|
|
|
|
|
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
def fit(...):
|
|
|
|
on_fit_start()
|
|
|
|
|
|
|
|
if global_rank == 0:
|
|
|
|
# prepare data is called on GLOBAL_ZERO only
|
|
|
|
prepare_data()
|
|
|
|
|
|
|
|
for gpu/tpu in gpu/tpus:
|
|
|
|
train_on_device(model.copy())
|
|
|
|
|
|
|
|
on_fit_end()
|
|
|
|
|
|
|
|
def train_on_device(model):
|
|
|
|
# setup is called PER DEVICE
|
|
|
|
setup()
|
|
|
|
configure_optimizers()
|
|
|
|
on_pretrain_routine_start()
|
|
|
|
|
|
|
|
for epoch in epochs:
|
|
|
|
train_loop()
|
|
|
|
|
|
|
|
teardown()
|
|
|
|
|
|
|
|
def train_loop():
|
|
|
|
on_train_epoch_start()
|
|
|
|
train_outs = []
|
|
|
|
for train_batch in train_dataloader():
|
|
|
|
on_train_batch_start()
|
|
|
|
|
|
|
|
# ----- train_step methods -------
|
|
|
|
out = training_step(train_batch)
|
|
|
|
train_outs.append(out)
|
|
|
|
|
|
|
|
loss = out.loss
|
|
|
|
|
|
|
|
backward()
|
|
|
|
on_after_backward()
|
|
|
|
optimizer_step()
|
|
|
|
on_before_zero_grad()
|
|
|
|
optimizer_zero_grad()
|
|
|
|
|
2020-10-08 01:48:38 +00:00
|
|
|
on_train_batch_end(out)
|
2020-08-11 23:39:43 +00:00
|
|
|
|
|
|
|
if should_check_val:
|
|
|
|
val_loop()
|
|
|
|
|
|
|
|
# end training epoch
|
|
|
|
logs = training_epoch_end(train_outs)
|
|
|
|
|
|
|
|
def val_loop():
|
|
|
|
model.eval()
|
|
|
|
torch.set_grad_enabled(False)
|
|
|
|
|
|
|
|
on_validation_epoch_start()
|
|
|
|
val_outs = []
|
|
|
|
for val_batch in val_dataloader():
|
|
|
|
on_validation_batch_start()
|
|
|
|
|
|
|
|
# -------- val step methods -------
|
|
|
|
out = validation_step(val_batch)
|
|
|
|
val_outs.append(out)
|
|
|
|
|
2020-10-08 00:41:56 +00:00
|
|
|
on_validation_batch_end(out)
|
2020-08-11 23:39:43 +00:00
|
|
|
|
|
|
|
validation_epoch_end(val_outs)
|
|
|
|
on_validation_epoch_end()
|
|
|
|
|
|
|
|
# set up for train
|
|
|
|
model.train()
|
|
|
|
torch.set_grad_enabled(True)
|
|
|
|
|
|
|
|
|
|
|
|
Advanced hooks
|
|
|
|
^^^^^^^^^^^^^^
|
|
|
|
Use these hooks to modify advanced functionality.
|
|
|
|
|
|
|
|
|
|
|
|
get_progress_bar_dict
|
|
|
|
~~~~~~~~~~~~~~~~~~~~~
|
|
|
|
|
2020-10-08 00:41:56 +00:00
|
|
|
.. autofunction:: pytorch_lightning.core.lightning.LightningModule.get_progress_bar_dict
|
2020-08-11 23:39:43 +00:00
|
|
|
:noindex:
|
|
|
|
|
|
|
|
tbptt_split_batch
|
|
|
|
~~~~~~~~~~~~~~~~~
|
|
|
|
|
2020-10-08 00:41:56 +00:00
|
|
|
.. autofunction:: pytorch_lightning.core.lightning.LightningModule.tbptt_split_batch
|
2020-08-11 23:39:43 +00:00
|
|
|
:noindex:
|
|
|
|
|
|
|
|
Checkpoint hooks
|
|
|
|
^^^^^^^^^^^^^^^^
|
|
|
|
These hooks allow you to modify checkpoints.
|
|
|
|
|
|
|
|
on_load_checkpoint
|
|
|
|
~~~~~~~~~~~~~~~~~~
|
|
|
|
|
2020-10-08 00:41:56 +00:00
|
|
|
.. autofunction:: pytorch_lightning.core.lightning.LightningModule.on_load_checkpoint
|
2020-08-11 23:39:43 +00:00
|
|
|
:noindex:
|
|
|
|
|
|
|
|
on_save_checkpoint
|
|
|
|
~~~~~~~~~~~~~~~~~~
|
|
|
|
|
2020-10-08 00:41:56 +00:00
|
|
|
.. autofunction:: pytorch_lightning.core.lightning.LightningModule.on_save_checkpoint
|
2020-08-11 23:39:43 +00:00
|
|
|
:noindex:
|
|
|
|
|
|
|
|
-------------
|
|
|
|
|
|
|
|
Data hooks
|
|
|
|
^^^^^^^^^^
|
|
|
|
Use these hooks if you want to couple a LightningModule to a dataset.
|
|
|
|
|
|
|
|
.. note:: The same collection of hooks is available in a DataModule class to decouple the data from the model.
|
|
|
|
|
|
|
|
train_dataloader
|
|
|
|
~~~~~~~~~~~~~~~~
|
|
|
|
|
2020-10-08 00:41:56 +00:00
|
|
|
.. autofunction:: pytorch_lightning.core.lightning.LightningModule.train_dataloader
|
2020-08-11 23:39:43 +00:00
|
|
|
:noindex:
|
|
|
|
|
|
|
|
val_dataloader
|
|
|
|
~~~~~~~~~~~~~~
|
|
|
|
|
2020-10-08 00:41:56 +00:00
|
|
|
.. autofunction:: pytorch_lightning.core.lightning.LightningModule.val_dataloader
|
2020-08-11 23:39:43 +00:00
|
|
|
:noindex:
|
|
|
|
|
|
|
|
test_dataloader
|
|
|
|
~~~~~~~~~~~~~~~
|
|
|
|
|
2020-10-08 00:41:56 +00:00
|
|
|
.. autofunction:: pytorch_lightning.core.lightning.LightningModule.test_dataloader
|
2020-08-11 23:39:43 +00:00
|
|
|
:noindex:
|
|
|
|
|
|
|
|
prepare_data
|
|
|
|
~~~~~~~~~~~~
|
|
|
|
|
2020-10-08 00:41:56 +00:00
|
|
|
.. autofunction:: pytorch_lightning.core.lightning.LightningModule.prepare_data
|
2020-08-11 23:39:43 +00:00
|
|
|
:noindex:
|
|
|
|
|
|
|
|
------------
|
|
|
|
|
|
|
|
Optimization hooks
|
|
|
|
^^^^^^^^^^^^^^^^^^
|
|
|
|
These are hooks related to the optimization procedure.
|
|
|
|
|
|
|
|
backward
|
|
|
|
~~~~~~~~
|
|
|
|
|
2020-10-08 00:41:56 +00:00
|
|
|
.. autofunction:: pytorch_lightning.core.lightning.LightningModule.backward
|
2020-08-11 23:39:43 +00:00
|
|
|
:noindex:
|
|
|
|
|
|
|
|
on_after_backward
|
|
|
|
~~~~~~~~~~~~~~~~~
|
|
|
|
|
2020-10-08 00:41:56 +00:00
|
|
|
.. autofunction:: pytorch_lightning.core.lightning.LightningModule.on_after_backward
|
2020-08-11 23:39:43 +00:00
|
|
|
:noindex:
|
|
|
|
|
|
|
|
on_before_zero_grad
|
|
|
|
~~~~~~~~~~~~~~~~~~~
|
2020-10-08 00:41:56 +00:00
|
|
|
.. autofunction:: pytorch_lightning.core.lightning.LightningModule.on_before_zero_grad
|
2020-08-11 23:39:43 +00:00
|
|
|
:noindex:
|
|
|
|
|
|
|
|
optimizer_step
|
|
|
|
~~~~~~~~~~~~~~
|
|
|
|
|
2020-10-08 00:41:56 +00:00
|
|
|
.. autofunction:: pytorch_lightning.core.lightning.LightningModule.optimizer_step
|
2020-08-11 23:39:43 +00:00
|
|
|
:noindex:
|
|
|
|
|
|
|
|
optimizer_zero_grad
|
|
|
|
~~~~~~~~~~~~~~~~~~~
|
|
|
|
|
2020-10-08 00:41:56 +00:00
|
|
|
.. autofunction:: pytorch_lightning.core.lightning.LightningModule.optimizer_zero_grad
|
2020-08-11 23:39:43 +00:00
|
|
|
:noindex:
|
|
|
|
|
|
|
|
Training lifecycle hooks
|
|
|
|
^^^^^^^^^^^^^^^^^^^^^^^^^
|
|
|
|
These hooks are called at various points of the fit lifecycle, including training, validation, and testing.
|
|
|
|
|
|
|
|
on_fit_start
|
|
|
|
~~~~~~~~~~~~
|
|
|
|
|
2020-10-08 00:41:56 +00:00
|
|
|
.. autofunction:: pytorch_lightning.core.hooks.ModelHooks.on_fit_start
|
2020-08-11 23:39:43 +00:00
|
|
|
:noindex:
|
|
|
|
|
|
|
|
on_fit_end
|
|
|
|
~~~~~~~~~~
|
|
|
|
|
2020-10-08 00:41:56 +00:00
|
|
|
.. autofunction:: pytorch_lightning.core.hooks.ModelHooks.on_fit_end
|
2020-08-11 23:39:43 +00:00
|
|
|
:noindex:
|
|
|
|
|
|
|
|
on_pretrain_routine_start
|
|
|
|
~~~~~~~~~~~~~~~~~~~~~~~~~
|
|
|
|
|
2020-10-08 00:41:56 +00:00
|
|
|
.. autofunction:: pytorch_lightning.core.hooks.ModelHooks.on_pretrain_routine_start
|
2020-08-11 23:39:43 +00:00
|
|
|
:noindex:
|
|
|
|
|
|
|
|
on_pretrain_routine_end
|
|
|
|
~~~~~~~~~~~~~~~~~~~~~~~
|
|
|
|
|
2020-10-08 00:41:56 +00:00
|
|
|
.. autofunction:: pytorch_lightning.core.hooks.ModelHooks.on_pretrain_routine_end
|
2020-08-11 23:39:43 +00:00
|
|
|
:noindex:
|
|
|
|
|
2020-09-14 08:35:14 +00:00
|
|
|
on_test_batch_start
|
2020-08-11 23:39:43 +00:00
|
|
|
~~~~~~~~~~~~~~~~~~~
|
|
|
|
|
2020-10-08 00:41:56 +00:00
|
|
|
.. autofunction:: pytorch_lightning.core.hooks.ModelHooks.on_test_batch_start
|
2020-08-11 23:39:43 +00:00
|
|
|
:noindex:
|
|
|
|
|
2020-09-14 08:35:14 +00:00
|
|
|
on_test_batch_end
|
2020-08-11 23:39:43 +00:00
|
|
|
~~~~~~~~~~~~~~~~~
|
|
|
|
|
2020-10-08 00:41:56 +00:00
|
|
|
.. autofunction:: pytorch_lightning.core.hooks.ModelHooks.on_test_batch_end
|
2020-08-11 23:39:43 +00:00
|
|
|
:noindex:
|
|
|
|
|
2020-09-14 08:35:14 +00:00
|
|
|
on_test_epoch_start
|
2020-08-11 23:39:43 +00:00
|
|
|
~~~~~~~~~~~~~~~~~~~
|
|
|
|
|
2020-10-08 00:41:56 +00:00
|
|
|
.. autofunction:: pytorch_lightning.core.hooks.ModelHooks.on_test_epoch_start
|
2020-08-11 23:39:43 +00:00
|
|
|
:noindex:
|
|
|
|
|
2020-09-14 08:35:14 +00:00
|
|
|
on_test_epoch_end
|
2020-08-11 23:39:43 +00:00
|
|
|
~~~~~~~~~~~~~~~~~
|
|
|
|
|
2020-10-08 00:41:56 +00:00
|
|
|
.. autofunction:: pytorch_lightning.core.hooks.ModelHooks.on_test_epoch_end
|
2020-08-11 23:39:43 +00:00
|
|
|
:noindex:
|
|
|
|
|
|
|
|
on_train_batch_start
|
|
|
|
~~~~~~~~~~~~~~~~~~~~
|
|
|
|
|
2020-10-08 00:41:56 +00:00
|
|
|
.. autofunction:: pytorch_lightning.core.hooks.ModelHooks.on_train_batch_start
|
2020-08-11 23:39:43 +00:00
|
|
|
:noindex:
|
|
|
|
|
|
|
|
on_train_batch_end
|
|
|
|
~~~~~~~~~~~~~~~~~~
|
|
|
|
|
2020-10-08 00:41:56 +00:00
|
|
|
.. autofunction:: pytorch_lightning.core.hooks.ModelHooks.on_train_batch_end
|
2020-08-11 23:39:43 +00:00
|
|
|
:noindex:
|
|
|
|
|
|
|
|
on_train_epoch_start
|
|
|
|
~~~~~~~~~~~~~~~~~~~~
|
|
|
|
|
2020-10-08 00:41:56 +00:00
|
|
|
.. autofunction:: pytorch_lightning.core.hooks.ModelHooks.on_train_epoch_start
|
2020-08-11 23:39:43 +00:00
|
|
|
:noindex:
|
|
|
|
|
|
|
|
on_train_epoch_end
|
|
|
|
~~~~~~~~~~~~~~~~~~
|
|
|
|
|
2020-10-08 00:41:56 +00:00
|
|
|
.. autofunction:: pytorch_lightning.core.hooks.ModelHooks.on_train_epoch_end
|
2020-08-11 23:39:43 +00:00
|
|
|
:noindex:
|
|
|
|
|
|
|
|
on_validation_batch_start
|
|
|
|
~~~~~~~~~~~~~~~~~~~~~~~~~
|
|
|
|
|
2020-10-08 00:41:56 +00:00
|
|
|
.. autofunction:: pytorch_lightning.core.hooks.ModelHooks.on_validation_batch_start
|
2020-08-11 23:39:43 +00:00
|
|
|
:noindex:
|
|
|
|
|
|
|
|
on_validation_batch_end
|
|
|
|
~~~~~~~~~~~~~~~~~~~~~~~
|
|
|
|
|
2020-10-08 00:41:56 +00:00
|
|
|
.. autofunction:: pytorch_lightning.core.hooks.ModelHooks.on_validation_batch_end
|
2020-08-11 23:39:43 +00:00
|
|
|
:noindex:
|
|
|
|
|
|
|
|
on_validation_epoch_start
|
|
|
|
~~~~~~~~~~~~~~~~~~~~~~~~~
|
|
|
|
|
2020-10-08 00:41:56 +00:00
|
|
|
.. autofunction:: pytorch_lightning.core.hooks.ModelHooks.on_validation_epoch_start
|
2020-08-11 23:39:43 +00:00
|
|
|
:noindex:
|
|
|
|
|
|
|
|
on_validation_epoch_end
|
|
|
|
~~~~~~~~~~~~~~~~~~~~~~~
|
|
|
|
|
2020-10-08 00:41:56 +00:00
|
|
|
.. autofunction:: pytorch_lightning.core.hooks.ModelHooks.on_validation_epoch_end
|
2020-08-11 23:39:43 +00:00
|
|
|
:noindex:
|
|
|
|
|
|
|
|
setup
|
|
|
|
~~~~~
|
|
|
|
|
2020-10-08 00:41:56 +00:00
|
|
|
.. autofunction:: pytorch_lightning.core.hooks.ModelHooks.setup
|
2020-08-11 23:39:43 +00:00
|
|
|
:noindex:
|
|
|
|
|
|
|
|
teardown
|
|
|
|
~~~~~~~~
|
|
|
|
|
2020-10-08 00:41:56 +00:00
|
|
|
.. autofunction:: pytorch_lightning.core.hooks.ModelHooks.teardown
|
2020-08-11 23:39:43 +00:00
|
|
|
:noindex:
|
|
|
|
|
|
|
|
transfer_batch_to_device
|
|
|
|
~~~~~~~~~~~~~~~~~~~~~~~~
|
|
|
|
|
2020-10-08 00:41:56 +00:00
|
|
|
.. autofunction:: pytorch_lightning.core.hooks.DataHooks.transfer_batch_to_device
|
2020-08-11 23:39:43 +00:00
|
|
|
:noindex:
|