Fix TBPTT example (#20528)
* Fix TBPTT example * Make example self-contained * Update imports * Add test
This commit is contained in:
parent
ee7fa43d2a
commit
76f0c54a46
|
@ -12,48 +12,91 @@ hidden states should be kept in-between each time-dimension split.
|
|||
.. code-block:: python
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.optim as optim
|
||||
import pytorch_lightning as pl
|
||||
from pytorch_lightning import LightningModule
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
|
||||
class LitModel(LightningModule):
|
||||
import lightning as L
|
||||
|
||||
|
||||
class AverageDataset(Dataset):
|
||||
def __init__(self, dataset_len=300, sequence_len=100):
|
||||
self.dataset_len = dataset_len
|
||||
self.sequence_len = sequence_len
|
||||
self.input_seq = torch.randn(dataset_len, sequence_len, 10)
|
||||
top, bottom = self.input_seq.chunk(2, -1)
|
||||
self.output_seq = top + bottom.roll(shifts=1, dims=-1)
|
||||
|
||||
def __len__(self):
|
||||
return self.dataset_len
|
||||
|
||||
def __getitem__(self, item):
|
||||
return self.input_seq[item], self.output_seq[item]
|
||||
|
||||
|
||||
class LitModel(L.LightningModule):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
self.batch_size = 10
|
||||
self.in_features = 10
|
||||
self.out_features = 5
|
||||
self.hidden_dim = 20
|
||||
|
||||
# 1. Switch to manual optimization
|
||||
self.automatic_optimization = False
|
||||
|
||||
self.truncated_bptt_steps = 10
|
||||
self.my_rnn = ParityModuleRNN() # Define RNN model using ParityModuleRNN
|
||||
|
||||
self.rnn = nn.LSTM(self.in_features, self.hidden_dim, batch_first=True)
|
||||
self.linear_out = nn.Linear(in_features=self.hidden_dim, out_features=self.out_features)
|
||||
|
||||
def forward(self, x, hs):
|
||||
seq, hs = self.rnn(x, hs)
|
||||
return self.linear_out(seq), hs
|
||||
|
||||
# 2. Remove the `hiddens` argument
|
||||
def training_step(self, batch, batch_idx):
|
||||
|
||||
# 3. Split the batch in chunks along the time dimension
|
||||
split_batches = split_batch(batch, self.truncated_bptt_steps)
|
||||
x, y = batch
|
||||
split_x, split_y = [
|
||||
x.tensor_split(self.truncated_bptt_steps, dim=1),
|
||||
y.tensor_split(self.truncated_bptt_steps, dim=1)
|
||||
]
|
||||
|
||||
hiddens = None
|
||||
optimizer = self.optimizers()
|
||||
losses = []
|
||||
|
||||
batch_size = 10
|
||||
hidden_dim = 20
|
||||
hiddens = torch.zeros(1, batch_size, hidden_dim, device=self.device)
|
||||
for split_batch in range(split_batches):
|
||||
# 4. Perform the optimization in a loop
|
||||
loss, hiddens = self.my_rnn(split_batch, hiddens)
|
||||
self.backward(loss)
|
||||
self.optimizer.step()
|
||||
self.optimizer.zero_grad()
|
||||
for x, y in zip(split_x, split_y):
|
||||
y_pred, hiddens = self(x, hiddens)
|
||||
loss = F.mse_loss(y_pred, y)
|
||||
|
||||
optimizer.zero_grad()
|
||||
self.manual_backward(loss)
|
||||
optimizer.step()
|
||||
|
||||
# 5. "Truncate"
|
||||
hiddens = hiddens.detach()
|
||||
hiddens = [h.detach() for h in hiddens]
|
||||
losses.append(loss.detach())
|
||||
|
||||
avg_loss = sum(losses) / len(losses)
|
||||
self.log("train_loss", avg_loss, prog_bar=True)
|
||||
|
||||
# 6. Remove the return of `hiddens`
|
||||
# Returning loss in manual optimization is not needed
|
||||
return None
|
||||
|
||||
def configure_optimizers(self):
|
||||
return optim.Adam(self.my_rnn.parameters(), lr=0.001)
|
||||
return optim.Adam(self.parameters(), lr=0.001)
|
||||
|
||||
def train_dataloader(self):
|
||||
return DataLoader(AverageDataset(), batch_size=self.batch_size)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
model = LitModel()
|
||||
trainer = pl.Trainer(max_epochs=5)
|
||||
trainer.fit(model, train_dataloader) # Define your own dataloader
|
||||
trainer = L.Trainer(max_epochs=5)
|
||||
trainer.fit(model)
|
||||
|
|
|
@ -219,3 +219,54 @@ class ParityModuleMNIST(LightningModule):
|
|||
|
||||
def train_dataloader(self):
|
||||
return DataLoader(MNIST(root=_PATH_DATASETS, train=True, download=True), batch_size=128, num_workers=1)
|
||||
|
||||
|
||||
class TBPTTModule(LightningModule):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
self.batch_size = 10
|
||||
self.in_features = 10
|
||||
self.out_features = 5
|
||||
self.hidden_dim = 20
|
||||
|
||||
self.automatic_optimization = False
|
||||
self.truncated_bptt_steps = 10
|
||||
|
||||
self.rnn = nn.LSTM(self.in_features, self.hidden_dim, batch_first=True)
|
||||
self.linear_out = nn.Linear(in_features=self.hidden_dim, out_features=self.out_features)
|
||||
|
||||
def forward(self, x, hs):
|
||||
seq, hs = self.rnn(x, hs)
|
||||
return self.linear_out(seq), hs
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
x, y = batch
|
||||
split_x, split_y = [
|
||||
x.tensor_split(self.truncated_bptt_steps, dim=1),
|
||||
y.tensor_split(self.truncated_bptt_steps, dim=1),
|
||||
]
|
||||
|
||||
hiddens = None
|
||||
optimizer = self.optimizers()
|
||||
losses = []
|
||||
|
||||
for x, y in zip(split_x, split_y):
|
||||
y_pred, hiddens = self(x, hiddens)
|
||||
loss = F.mse_loss(y_pred, y)
|
||||
|
||||
optimizer.zero_grad()
|
||||
self.manual_backward(loss)
|
||||
optimizer.step()
|
||||
|
||||
# "Truncate"
|
||||
hiddens = [h.detach() for h in hiddens]
|
||||
losses.append(loss.detach())
|
||||
|
||||
return
|
||||
|
||||
def configure_optimizers(self):
|
||||
return torch.optim.Adam(self.parameters(), lr=0.001)
|
||||
|
||||
def train_dataloader(self):
|
||||
return DataLoader(AverageDataset(), batch_size=self.batch_size)
|
||||
|
|
|
@ -17,7 +17,7 @@ import pytest
|
|||
from lightning.pytorch import Trainer
|
||||
from lightning.pytorch.demos.boring_classes import BoringModel
|
||||
|
||||
from tests_pytorch.helpers.advanced_models import BasicGAN, ParityModuleMNIST, ParityModuleRNN
|
||||
from tests_pytorch.helpers.advanced_models import BasicGAN, ParityModuleMNIST, ParityModuleRNN, TBPTTModule
|
||||
from tests_pytorch.helpers.datamodules import ClassifDataModule, RegressDataModule
|
||||
from tests_pytorch.helpers.runif import RunIf
|
||||
from tests_pytorch.helpers.simple_models import ClassificationModel, RegressionModel
|
||||
|
@ -49,3 +49,10 @@ def test_models(tmp_path, data_class, model_class):
|
|||
model.to_torchscript()
|
||||
if data_class:
|
||||
model.to_onnx(os.path.join(tmp_path, "my-model.onnx"), input_sample=dm.sample)
|
||||
|
||||
|
||||
def test_tbptt(tmp_path):
|
||||
model = TBPTTModule()
|
||||
|
||||
trainer = Trainer(default_root_dir=tmp_path, max_epochs=1)
|
||||
trainer.fit(model)
|
||||
|
|
Loading…
Reference in New Issue