Fix TBPTT example (#20528)

* Fix TBPTT example

* Make example self-contained

* Update imports

* Add test
This commit is contained in:
Luca Antiga 2025-01-06 18:51:10 +01:00 committed by GitHub
parent ee7fa43d2a
commit 76f0c54a46
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 122 additions and 21 deletions

View File

@ -12,48 +12,91 @@ hidden states should be kept in-between each time-dimension split.
.. code-block:: python .. code-block:: python
import torch import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim import torch.optim as optim
import pytorch_lightning as pl from torch.utils.data import Dataset, DataLoader
from pytorch_lightning import LightningModule
class LitModel(LightningModule): import lightning as L
class AverageDataset(Dataset):
def __init__(self, dataset_len=300, sequence_len=100):
self.dataset_len = dataset_len
self.sequence_len = sequence_len
self.input_seq = torch.randn(dataset_len, sequence_len, 10)
top, bottom = self.input_seq.chunk(2, -1)
self.output_seq = top + bottom.roll(shifts=1, dims=-1)
def __len__(self):
return self.dataset_len
def __getitem__(self, item):
return self.input_seq[item], self.output_seq[item]
class LitModel(L.LightningModule):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.batch_size = 10
self.in_features = 10
self.out_features = 5
self.hidden_dim = 20
# 1. Switch to manual optimization # 1. Switch to manual optimization
self.automatic_optimization = False self.automatic_optimization = False
self.truncated_bptt_steps = 10 self.truncated_bptt_steps = 10
self.my_rnn = ParityModuleRNN() # Define RNN model using ParityModuleRNN
self.rnn = nn.LSTM(self.in_features, self.hidden_dim, batch_first=True)
self.linear_out = nn.Linear(in_features=self.hidden_dim, out_features=self.out_features)
def forward(self, x, hs):
seq, hs = self.rnn(x, hs)
return self.linear_out(seq), hs
# 2. Remove the `hiddens` argument # 2. Remove the `hiddens` argument
def training_step(self, batch, batch_idx): def training_step(self, batch, batch_idx):
# 3. Split the batch in chunks along the time dimension # 3. Split the batch in chunks along the time dimension
split_batches = split_batch(batch, self.truncated_bptt_steps) x, y = batch
split_x, split_y = [
x.tensor_split(self.truncated_bptt_steps, dim=1),
y.tensor_split(self.truncated_bptt_steps, dim=1)
]
batch_size = 10 hiddens = None
hidden_dim = 20 optimizer = self.optimizers()
hiddens = torch.zeros(1, batch_size, hidden_dim, device=self.device) losses = []
for split_batch in range(split_batches):
# 4. Perform the optimization in a loop # 4. Perform the optimization in a loop
loss, hiddens = self.my_rnn(split_batch, hiddens) for x, y in zip(split_x, split_y):
self.backward(loss) y_pred, hiddens = self(x, hiddens)
self.optimizer.step() loss = F.mse_loss(y_pred, y)
self.optimizer.zero_grad()
optimizer.zero_grad()
self.manual_backward(loss)
optimizer.step()
# 5. "Truncate" # 5. "Truncate"
hiddens = hiddens.detach() hiddens = [h.detach() for h in hiddens]
losses.append(loss.detach())
avg_loss = sum(losses) / len(losses)
self.log("train_loss", avg_loss, prog_bar=True)
# 6. Remove the return of `hiddens` # 6. Remove the return of `hiddens`
# Returning loss in manual optimization is not needed # Returning loss in manual optimization is not needed
return None return None
def configure_optimizers(self): def configure_optimizers(self):
return optim.Adam(self.my_rnn.parameters(), lr=0.001) return optim.Adam(self.parameters(), lr=0.001)
def train_dataloader(self):
return DataLoader(AverageDataset(), batch_size=self.batch_size)
if __name__ == "__main__": if __name__ == "__main__":
model = LitModel() model = LitModel()
trainer = pl.Trainer(max_epochs=5) trainer = L.Trainer(max_epochs=5)
trainer.fit(model, train_dataloader) # Define your own dataloader trainer.fit(model)

View File

@ -219,3 +219,54 @@ class ParityModuleMNIST(LightningModule):
def train_dataloader(self): def train_dataloader(self):
return DataLoader(MNIST(root=_PATH_DATASETS, train=True, download=True), batch_size=128, num_workers=1) return DataLoader(MNIST(root=_PATH_DATASETS, train=True, download=True), batch_size=128, num_workers=1)
class TBPTTModule(LightningModule):
def __init__(self):
super().__init__()
self.batch_size = 10
self.in_features = 10
self.out_features = 5
self.hidden_dim = 20
self.automatic_optimization = False
self.truncated_bptt_steps = 10
self.rnn = nn.LSTM(self.in_features, self.hidden_dim, batch_first=True)
self.linear_out = nn.Linear(in_features=self.hidden_dim, out_features=self.out_features)
def forward(self, x, hs):
seq, hs = self.rnn(x, hs)
return self.linear_out(seq), hs
def training_step(self, batch, batch_idx):
x, y = batch
split_x, split_y = [
x.tensor_split(self.truncated_bptt_steps, dim=1),
y.tensor_split(self.truncated_bptt_steps, dim=1),
]
hiddens = None
optimizer = self.optimizers()
losses = []
for x, y in zip(split_x, split_y):
y_pred, hiddens = self(x, hiddens)
loss = F.mse_loss(y_pred, y)
optimizer.zero_grad()
self.manual_backward(loss)
optimizer.step()
# "Truncate"
hiddens = [h.detach() for h in hiddens]
losses.append(loss.detach())
return
def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=0.001)
def train_dataloader(self):
return DataLoader(AverageDataset(), batch_size=self.batch_size)

View File

@ -17,7 +17,7 @@ import pytest
from lightning.pytorch import Trainer from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.demos.boring_classes import BoringModel
from tests_pytorch.helpers.advanced_models import BasicGAN, ParityModuleMNIST, ParityModuleRNN from tests_pytorch.helpers.advanced_models import BasicGAN, ParityModuleMNIST, ParityModuleRNN, TBPTTModule
from tests_pytorch.helpers.datamodules import ClassifDataModule, RegressDataModule from tests_pytorch.helpers.datamodules import ClassifDataModule, RegressDataModule
from tests_pytorch.helpers.runif import RunIf from tests_pytorch.helpers.runif import RunIf
from tests_pytorch.helpers.simple_models import ClassificationModel, RegressionModel from tests_pytorch.helpers.simple_models import ClassificationModel, RegressionModel
@ -49,3 +49,10 @@ def test_models(tmp_path, data_class, model_class):
model.to_torchscript() model.to_torchscript()
if data_class: if data_class:
model.to_onnx(os.path.join(tmp_path, "my-model.onnx"), input_sample=dm.sample) model.to_onnx(os.path.join(tmp_path, "my-model.onnx"), input_sample=dm.sample)
def test_tbptt(tmp_path):
model = TBPTTModule()
trainer = Trainer(default_root_dir=tmp_path, max_epochs=1)
trainer.fit(model)