Update `docs/source-pytorch/common/lightning_module.rst` (#18451)
Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
This commit is contained in: parent 451226508c · commit a0133863a3
@@ -85,38 +85,42 @@ Here are the only required methods.

 .. code-block:: python

     import lightning.pytorch as pl
-    import torch.nn as nn
-    import torch.nn.functional as F
+    import torch
+
+    from lightning.pytorch.demos import Transformer


-    class LitModel(pl.LightningModule):
-        def __init__(self):
+    class LightningTransformer(pl.LightningModule):
+        def __init__(self, vocab_size):
             super().__init__()
-            self.l1 = nn.Linear(28 * 28, 10)
+            self.model = Transformer(vocab_size=vocab_size)

-        def forward(self, x):
-            return torch.relu(self.l1(x.view(x.size(0), -1)))
+        def forward(self, inputs, target):
+            return self.model(inputs, target)

         def training_step(self, batch, batch_idx):
-            x, y = batch
-            y_hat = self(x)
-            loss = F.cross_entropy(y_hat, y)
+            inputs, target = batch
+            output = self(inputs, target)
+            loss = torch.nn.functional.nll_loss(output, target.view(-1))
             return loss

         def configure_optimizers(self):
-            return torch.optim.Adam(self.parameters(), lr=0.02)
+            return torch.optim.SGD(self.model.parameters(), lr=0.1)

 Which you can train by doing:

 .. code-block:: python

-    train_loader = DataLoader(MNIST(os.getcwd(), download=True, transform=transforms.ToTensor()))
-    trainer = pl.Trainer(max_epochs=1)
-    model = LitModel()
-    trainer.fit(model, train_dataloaders=train_loader)
+    from torch.utils.data import DataLoader

-The LightningModule has many convenience methods, but the core ones you need to know about are:
+    dataset = pl.demos.WikiText2()
+    dataloader = DataLoader(dataset)
+    model = LightningTransformer(vocab_size=dataset.vocab_size)
+
+    trainer = pl.Trainer(fast_dev_run=100)
+    trainer.fit(model=model, train_dataloaders=dataloader)
+
+The LightningModule has many convenient methods, but the core ones you need to know about are:

 .. list-table::
    :widths: 50 50
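For readers skimming the hunk above: assembled, the updated example reads as below. This is a sketch of the post-commit snippet, assuming a Lightning 2.x install where ``lightning.pytorch.demos`` ships ``Transformer`` and ``WikiText2``; the explicit ``WikiText2`` import replaces the ``pl.demos.WikiText2()`` access used in the hunk.

.. code-block:: python

    import lightning.pytorch as pl
    import torch
    from torch.utils.data import DataLoader

    from lightning.pytorch.demos import Transformer, WikiText2


    class LightningTransformer(pl.LightningModule):
        def __init__(self, vocab_size):
            super().__init__()
            self.model = Transformer(vocab_size=vocab_size)

        def forward(self, inputs, target):
            return self.model(inputs, target)

        def training_step(self, batch, batch_idx):
            inputs, target = batch
            output = self(inputs, target)
            # the demo Transformer ends in log_softmax, so nll_loss is the matching criterion
            return torch.nn.functional.nll_loss(output, target.view(-1))

        def configure_optimizers(self):
            return torch.optim.SGD(self.model.parameters(), lr=0.1)


    dataset = WikiText2()
    dataloader = DataLoader(dataset)
    model = LightningTransformer(vocab_size=dataset.vocab_size)

    # fast_dev_run=100 runs only 100 batches as a smoke test (no checkpoints, no logger)
    trainer = pl.Trainer(fast_dev_run=100)
    trainer.fit(model=model, train_dataloaders=dataloader)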
@@ -152,15 +156,15 @@ To activate the training loop, override the :meth:`~lightning.pytorch.core.modul

 .. code-block:: python

-    class LitClassifier(pl.LightningModule):
-        def __init__(self, model):
+    class LightningTransformer(pl.LightningModule):
+        def __init__(self, vocab_size):
             super().__init__()
-            self.model = model
+            self.model = Transformer(vocab_size=vocab_size)

         def training_step(self, batch, batch_idx):
-            x, y = batch
-            y_hat = self.model(x)
-            loss = F.cross_entropy(y_hat, y)
+            inputs, target = batch
+            output = self.model(inputs, target)
+            loss = torch.nn.functional.nll_loss(output, target.view(-1))
             return loss

 Under the hood, Lightning does the following (pseudocode):
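The pseudocode the last context line refers to sits outside this hunk. As a rough, hedged mental model (not Lightning's actual implementation), the part of ``fit`` that ``training_step`` plugs into behaves like:

.. code-block:: python

    # approximate sketch only; the real loop also handles hooks, precision, and accelerators
    model.train()
    torch.set_grad_enabled(True)

    for batch_idx, batch in enumerate(train_dataloader):
        loss = model.training_step(batch, batch_idx)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()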
@@ -191,15 +195,15 @@ If you want to calculate epoch-level metrics and log them, use :meth:`~lightning

 .. code-block:: python

     def training_step(self, batch, batch_idx):
-        x, y = batch
-        y_hat = self.model(x)
-        loss = F.cross_entropy(y_hat, y)
+        inputs, target = batch
+        output = self.model(inputs, target)
+        loss = torch.nn.functional.nll_loss(output, target.view(-1))

         # logs metrics for each training_step,
         # and the average across the epoch, to the progress bar and logger
         self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
         return loss

 The :meth:`~lightning.pytorch.core.module.LightningModule.log` method automatically reduces the
 requested metrics across a complete epoch and devices. Here's the pseudocode of what it does under the hood:
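A side note on ``self.log``: the epoch-level value defaults to the mean across steps. If that reduction is wrong for a metric, ``log`` accepts a ``reduce_fx`` argument; a hedged sketch (``torch.sum`` here is illustrative, not part of the commit):

.. code-block:: python

    def training_step(self, batch, batch_idx):
        inputs, target = batch
        output = self.model(inputs, target)
        loss = torch.nn.functional.nll_loss(output, target.view(-1))
        # accumulate with a sum instead of the default mean across the epoch
        self.log("train_loss_sum", loss, on_epoch=True, reduce_fx=torch.sum)
        return loss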
@@ -230,25 +234,25 @@ override the :meth:`~lightning.pytorch.LightningModule.on_train_epoch_end` metho

 .. code-block:: python

-    def __init__(self):
-        super().__init__()
-        self.training_step_outputs = []
+    class LightningTransformer(pl.LightningModule):
+        def __init__(self, vocab_size):
+            super().__init__()
+            self.model = Transformer(vocab_size=vocab_size)
+            self.training_step_outputs = []

-    def training_step(self, batch, batch_idx):
-        x, y = batch
-        y_hat = self.model(x)
-        loss = F.cross_entropy(y_hat, y)
-        preds = ...
-        self.training_step_outputs.append(preds)
-        return loss
+        def training_step(self, batch, batch_idx):
+            inputs, target = batch
+            output = self.model(inputs, target)
+            loss = torch.nn.functional.nll_loss(output, target.view(-1))
+            preds = ...
+            self.training_step_outputs.append(preds)
+            return loss

-    def on_train_epoch_end(self):
-        all_preds = torch.stack(self.training_step_outputs)
-        # do something with all preds
-        ...
-        self.training_step_outputs.clear()  # free memory
+        def on_train_epoch_end(self):
+            all_preds = torch.stack(self.training_step_outputs)
+            # do something with all preds
+            ...
+            self.training_step_outputs.clear()  # free memory
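One practical caveat for this pattern: ``torch.stack`` requires every stored tensor to share a shape, which the final (possibly smaller) batch can violate. A hedged alternative when the step outputs are per-sample predictions:

.. code-block:: python

    def on_train_epoch_end(self):
        # torch.cat tolerates a shorter last batch; torch.stack would raise on a mismatch
        all_preds = torch.cat(self.training_step_outputs, dim=0)
        ...
        self.training_step_outputs.clear()  # free memory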
 ------------------
@@ -264,10 +268,10 @@ To activate the validation loop while training, override the :meth:`~lightning.p

 .. code-block:: python

-    class LitModel(pl.LightningModule):
+    class LightningTransformer(pl.LightningModule):
         def validation_step(self, batch, batch_idx):
-            x, y = batch
-            y_hat = self.model(x)
-            loss = F.cross_entropy(y_hat, y)
+            inputs, target = batch
+            output = self.model(inputs, target)
+            loss = torch.nn.functional.nll_loss(output, target.view(-1))
             self.log("val_loss", loss)
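Since this model trains against a negative log-likelihood, a natural extra metric to surface during validation is perplexity. A purely illustrative sketch, not part of the commit:

.. code-block:: python

    def validation_step(self, batch, batch_idx):
        inputs, target = batch
        output = self.model(inputs, target)
        loss = torch.nn.functional.nll_loss(output, target.view(-1))
        self.log("val_loss", loss)
        # perplexity = exp(mean NLL); illustrative only
        self.log("val_ppl", torch.exp(loss), prog_bar=True)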
@@ -300,8 +304,8 @@ and calling :meth:`~lightning.pytorch.trainer.trainer.Trainer.validate`.

 .. code-block:: python

-    model = Model()
-    trainer = Trainer()
+    model = LightningTransformer(vocab_size=dataset.vocab_size)
+    trainer = pl.Trainer()
     trainer.validate(model)
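``trainer.validate(model)`` with no loader argument assumes the model (or an attached datamodule) defines ``val_dataloader``; the loader can equally be passed explicitly. A minimal sketch under that assumption:

.. code-block:: python

    model = LightningTransformer(vocab_size=dataset.vocab_size)
    trainer = pl.Trainer()
    # explicit loader; otherwise Lightning asks the model or datamodule for one
    trainer.validate(model, dataloaders=DataLoader(dataset))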
 .. note::
@@ -322,25 +326,26 @@ Note that this method is called before :meth:`~lightning.pytorch.LightningModule

 .. code-block:: python

-    def __init__(self):
-        super().__init__()
-        self.validation_step_outputs = []
+    class LightningTransformer(pl.LightningModule):
+        def __init__(self, vocab_size):
+            super().__init__()
+            self.model = Transformer(vocab_size=vocab_size)
+            self.validation_step_outputs = []

-    def validation_step(self, batch, batch_idx):
-        x, y = batch
-        y_hat = self.model(x)
-        loss = F.cross_entropy(y_hat, y)
-        pred = ...
-        self.validation_step_outputs.append(pred)
-        return pred
+        def validation_step(self, batch, batch_idx):
+            inputs, target = batch
+            output = self.model(inputs, target)
+            loss = torch.nn.functional.nll_loss(output, target.view(-1))
+            pred = ...
+            self.validation_step_outputs.append(pred)
+            return pred

-    def on_validation_epoch_end(self):
-        all_preds = torch.stack(self.validation_step_outputs)
-        # do something with all preds
-        ...
-        self.validation_step_outputs.clear()  # free memory
+        def on_validation_epoch_end(self):
+            all_preds = torch.stack(self.validation_step_outputs)
+            # do something with all preds
+            ...
+            self.validation_step_outputs.clear()  # free memory
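The ``pred = ...`` placeholder is deliberately elided in the docs. Purely as one hypothetical way to fill it for this language model (greedy argmax over the vocabulary dimension):

.. code-block:: python

    def validation_step(self, batch, batch_idx):
        inputs, target = batch
        output = self.model(inputs, target)
        loss = torch.nn.functional.nll_loss(output, target.view(-1))
        pred = output.argmax(dim=-1)  # hypothetical: greedy token choices
        self.validation_step_outputs.append(pred)
        return pred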
 ----------------
@@ -358,9 +363,10 @@ The only difference is that the test loop is only called when :meth:`~lightning.

 .. code-block:: python

-    model = Model()
-    trainer = Trainer()
-    trainer.fit(model)
+    model = LightningTransformer(vocab_size=dataset.vocab_size)
+    dataloader = DataLoader(dataset)
+    trainer = pl.Trainer()
+    trainer.fit(model=model, train_dataloaders=dataloader)

     # automatically loads the best weights for you
     trainer.test(model)
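A hedged aside on "automatically loads the best weights": that relies on a ``ModelCheckpoint`` callback having monitored a metric during ``fit``. Being explicit also works:

.. code-block:: python

    # "best" resolves through the ModelCheckpoint callback from the fit above
    trainer.test(model, dataloaders=DataLoader(dataset), ckpt_path="best")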
@@ -370,17 +376,23 @@ There are two ways to call ``test()``:

 .. code-block:: python

     # call after training
-    trainer = Trainer()
-    trainer.fit(model)
+    trainer = pl.Trainer()
+    trainer.fit(model=model, train_dataloaders=dataloader)

     # automatically auto-loads the best weights from the previous run
-    trainer.test(dataloaders=test_dataloader)
+    trainer.test(dataloaders=test_dataloaders)

     # or call with pretrained model
-    model = MyLightningModule.load_from_checkpoint(PATH)
-    trainer = Trainer()
+    model = LightningTransformer.load_from_checkpoint(PATH)
+    dataset = WikiText2()
+    test_dataloader = DataLoader(dataset)
+    trainer = pl.Trainer()
     trainer.test(model, dataloaders=test_dataloader)

+.. note::
+    `WikiText2` is used in a manner that does not create a train, test, val split. This is done for illustrative purposes only.
+    A proper split can be created in :meth:`lightning.pytorch.core.LightningModule.setup` or :meth:`lightning.pytorch.core.LightningDataModule.setup`.
+
 .. note::
     It is recommended to validate on single device to ensure each sample/batch gets evaluated exactly once.
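The new note points at ``setup`` for a proper split. A hedged sketch of what that could look like; the 80/10/10 fractions are illustrative, not prescribed by the commit:

.. code-block:: python

    from torch.utils.data import random_split


    class LightningTransformer(pl.LightningModule):
        def setup(self, stage):
            dataset = WikiText2()
            # illustrative split; fractional lengths need torch >= 1.13
            self.train_set, self.val_set, self.test_set = random_split(
                dataset, [0.8, 0.1, 0.1], generator=torch.Generator().manual_seed(42)
            )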
@@ -403,24 +415,18 @@ By default, the :meth:`~lightning.pytorch.core.module.LightningModule.predict_st
 :meth:`~lightning.pytorch.core.module.LightningModule.forward` method. In order to customize this behaviour,
 simply override the :meth:`~lightning.pytorch.core.module.LightningModule.predict_step` method.

-For the example let's override ``predict_step`` and try out `Monte Carlo Dropout <https://arxiv.org/pdf/1506.02142.pdf>`_:
+For the example let's override ``predict_step``:

 .. code-block:: python

-    class LitMCdropoutModel(pl.LightningModule):
-        def __init__(self, model, mc_iteration):
+    class LightningTransformer(pl.LightningModule):
+        def __init__(self, vocab_size):
             super().__init__()
-            self.model = model
-            self.dropout = nn.Dropout()
-            self.mc_iteration = mc_iteration
+            self.model = Transformer(vocab_size=vocab_size)

-        def predict_step(self, batch, batch_idx):
-            # enable Monte Carlo Dropout
-            self.dropout.train()
-
-            # take average of `self.mc_iteration` iterations
-            pred = torch.vstack([self.dropout(self.model(x)).unsqueeze(0) for _ in range(self.mc_iteration)]).mean(dim=0)
-            return pred
+        def predict_step(self, batch):
+            inputs, target = batch
+            return self.model(inputs, target)

 Under the hood, Lightning does the following (pseudocode):
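Note the new ``predict_step`` drops ``batch_idx``: Lightning inspects the hook's signature and passes only the arguments it declares, so the shorter spelling and the fuller one below should behave the same (the defaults here are a sketch, not a quote of the API):

.. code-block:: python

    def predict_step(self, batch, batch_idx=0, dataloader_idx=0):
        inputs, target = batch
        return self.model(inputs, target)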
@@ -440,15 +446,17 @@ There are two ways to call ``predict()``:

 .. code-block:: python

     # call after training
-    trainer = Trainer()
-    trainer.fit(model)
+    trainer = pl.Trainer()
+    trainer.fit(model=model, train_dataloaders=dataloader)

     # automatically auto-loads the best weights from the previous run
     predictions = trainer.predict(dataloaders=predict_dataloader)

     # or call with pretrained model
-    model = MyLightningModule.load_from_checkpoint(PATH)
-    trainer = Trainer()
+    model = LightningTransformer.load_from_checkpoint(PATH)
+    dataset = pl.demos.WikiText2()
+    test_dataloader = DataLoader(dataset)
+    trainer = pl.Trainer()
     predictions = trainer.predict(model, dataloaders=test_dataloader)
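For orientation, ``Trainer.predict`` collects one entry per batch, and collection can be switched off for large outputs. A brief sketch:

.. code-block:: python

    predictions = trainer.predict(model, dataloaders=test_dataloader)
    print(len(predictions))  # one prediction tensor per batch

    # avoid holding every prediction in memory
    trainer.predict(model, dataloaders=test_dataloader, return_predictions=False)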
 Inference in Research
@@ -460,15 +468,31 @@ If you want to perform inference with the system, you can add a ``forward`` meth

 .. code-block:: python

-    class Autoencoder(pl.LightningModule):
-        def forward(self, x):
-            return self.decoder(x)
+    class LightningTransformer(pl.LightningModule):
+        def __init__(self, vocab_size):
+            super().__init__()
+            self.model = Transformer(vocab_size=vocab_size)
+
+        def forward(self, batch):
+            inputs, target = batch
+            return self.model(inputs, target)
+
+        def training_step(self, batch, batch_idx):
+            inputs, target = batch
+            output = self.model(inputs, target)
+            loss = torch.nn.functional.nll_loss(output, target.view(-1))
+            return loss
+
+        def configure_optimizers(self):
+            return torch.optim.SGD(self.model.parameters(), lr=0.1)


-    model = Autoencoder()
+    model = LightningTransformer(vocab_size=dataset.vocab_size)

     model.eval()
     with torch.no_grad():
-        reconstruction = model(embedding)
+        batch = dataloader.dataset[0]
+        pred = model(batch)
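A subtlety in the new snippet: ``dataloader.dataset[0]`` returns a single un-batched ``(inputs, target)`` pair, whereas training saw loader-collated tensors. If the model complains about shapes, a hedged fix is to fake a batch of one:

.. code-block:: python

    model.eval()
    with torch.no_grad():
        inputs, target = dataloader.dataset[0]
        # unsqueeze adds the batch dimension the DataLoader would have provided
        pred = model((inputs.unsqueeze(0), target.unsqueeze(0)))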
 The advantage of adding a forward is that in complex systems, you can do a much more involved inference procedure,
 such as text generation:
@@ -618,7 +642,7 @@ checkpoint, which simplifies model re-instantiation after training.

 .. code-block:: python

-    class LitMNIST(LightningModule):
+    class LitMNIST(pl.LightningModule):
         def __init__(self, layer_1_dim=128, learning_rate=1e-2):
             super().__init__()
             # call this to save (layer_1_dim=128, learning_rate=1e-4) to the checkpoint
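For context on what ``save_hyperparameters()`` buys you after training, a sketch (the checkpoint path is a placeholder):

.. code-block:: python

    import torch.nn as nn


    class LitMNIST(pl.LightningModule):
        def __init__(self, layer_1_dim=128, learning_rate=1e-2):
            super().__init__()
            self.save_hyperparameters()
            # saved arguments become available under self.hparams
            self.layer_1 = nn.Linear(28 * 28, self.hparams.layer_1_dim)


    # the checkpoint carries the constructor arguments, so none are needed here
    model = LitMNIST.load_from_checkpoint("path/to/checkpoint.ckpt")
    # and any of them can be overridden at load time
    model = LitMNIST.load_from_checkpoint("path/to/checkpoint.ckpt", layer_1_dim=256)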
@@ -642,7 +666,7 @@ parameters should be provided back when reloading the LightningModule. In this c

 .. code-block:: python

-    class LitMNIST(LightningModule):
+    class LitMNIST(pl.LightningModule):
         def __init__(self, loss_fx, generator_network, layer_1_dim=128):
             super().__init__()
             self.layer_1_dim = layer_1_dim
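For the partial-save case in this hunk, anything excluded from ``save_hyperparameters`` has to be supplied again on reload. A sketch; ``MyGenerator`` and the checkpoint path are hypothetical:

.. code-block:: python

    model = LitMNIST.load_from_checkpoint(
        "path/to/checkpoint.ckpt",
        loss_fx=torch.nn.functional.cross_entropy,  # not stored in the checkpoint
        generator_network=MyGenerator(),  # hypothetical module; not stored either
    )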