Fix TBPTT example (#20528)

* Fix TBPTT example
* Make example self-contained
* Update imports
* Add test
This commit is contained in:
parent ee7fa43d2a
commit 76f0c54a46
@@ -12,48 +12,91 @@ hidden states should be kept in-between each time-dimension split.

 .. code-block:: python
 
     import torch
+    import torch.nn as nn
+    import torch.nn.functional as F
     import torch.optim as optim
-    import pytorch_lightning as pl
-    from pytorch_lightning import LightningModule
+    from torch.utils.data import Dataset, DataLoader
+
+    import lightning as L
+
+
+    class AverageDataset(Dataset):
+        def __init__(self, dataset_len=300, sequence_len=100):
+            self.dataset_len = dataset_len
+            self.sequence_len = sequence_len
+            self.input_seq = torch.randn(dataset_len, sequence_len, 10)
+            top, bottom = self.input_seq.chunk(2, -1)
+            self.output_seq = top + bottom.roll(shifts=1, dims=-1)
+
+        def __len__(self):
+            return self.dataset_len
+
+        def __getitem__(self, item):
+            return self.input_seq[item], self.output_seq[item]
 
-    class LitModel(LightningModule):
+
+    class LitModel(L.LightningModule):
         def __init__(self):
             super().__init__()
 
+            self.batch_size = 10
+            self.in_features = 10
+            self.out_features = 5
+            self.hidden_dim = 20
+
             # 1. Switch to manual optimization
             self.automatic_optimization = False
 
             self.truncated_bptt_steps = 10
-            self.my_rnn = ParityModuleRNN()  # Define RNN model using ParityModuleRNN
+
+            self.rnn = nn.LSTM(self.in_features, self.hidden_dim, batch_first=True)
+            self.linear_out = nn.Linear(in_features=self.hidden_dim, out_features=self.out_features)
+
+        def forward(self, x, hs):
+            seq, hs = self.rnn(x, hs)
+            return self.linear_out(seq), hs
 
         # 2. Remove the `hiddens` argument
         def training_step(self, batch, batch_idx):
-
             # 3. Split the batch in chunks along the time dimension
-            split_batches = split_batch(batch, self.truncated_bptt_steps)
+            x, y = batch
+            split_x, split_y = [
+                x.tensor_split(self.truncated_bptt_steps, dim=1),
+                y.tensor_split(self.truncated_bptt_steps, dim=1)
+            ]
 
-            batch_size = 10
-            hidden_dim = 20
-            hiddens = torch.zeros(1, batch_size, hidden_dim, device=self.device)
-            for split_batch in range(split_batches):
-                # 4. Perform the optimization in a loop
-                loss, hiddens = self.my_rnn(split_batch, hiddens)
-                self.backward(loss)
-                self.optimizer.step()
-                self.optimizer.zero_grad()
+            hiddens = None
+            optimizer = self.optimizers()
+            losses = []
 
-                # 5. "Truncate"
-                hiddens = hiddens.detach()
+            # 4. Perform the optimization in a loop
+            for x, y in zip(split_x, split_y):
+                y_pred, hiddens = self(x, hiddens)
+                loss = F.mse_loss(y_pred, y)
+
+                optimizer.zero_grad()
+                self.manual_backward(loss)
+                optimizer.step()
+
+                # 5. "Truncate"
+                hiddens = [h.detach() for h in hiddens]
+                losses.append(loss.detach())
+
+            avg_loss = sum(losses) / len(losses)
+            self.log("train_loss", avg_loss, prog_bar=True)
 
             # 6. Remove the return of `hiddens`
             # Returning loss in manual optimization is not needed
             return None
 
         def configure_optimizers(self):
-            return optim.Adam(self.my_rnn.parameters(), lr=0.001)
+            return optim.Adam(self.parameters(), lr=0.001)
+
+        def train_dataloader(self):
+            return DataLoader(AverageDataset(), batch_size=self.batch_size)
 
 
     if __name__ == "__main__":
         model = LitModel()
-        trainer = pl.Trainer(max_epochs=5)
-        trainer.fit(model, train_dataloader)  # Define your own dataloader
+        trainer = L.Trainer(max_epochs=5)
+        trainer.fit(model)
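Note: the heart of the corrected example is step 5, the "truncate" in truncated backpropagation through time. As the hunk context says, the hidden states should be kept in-between each time-dimension split, but their autograd history must not be. Below is a minimal sketch of the same pattern in plain PyTorch, outside of Lightning; it is illustrative only and not part of this commit, and all names in it are made up.

.. code-block:: python

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    # Carry the LSTM state across time chunks, but detach it after each
    # optimizer step so gradients only flow within the current chunk.
    rnn = nn.LSTM(input_size=10, hidden_size=20, batch_first=True)
    head = nn.Linear(20, 5)
    opt = torch.optim.Adam(list(rnn.parameters()) + list(head.parameters()), lr=1e-3)

    x = torch.randn(4, 100, 10)  # (batch, time, features)
    y = torch.randn(4, 100, 5)

    hiddens = None  # nn.LSTM initializes the state to zeros when given None
    for xt, yt in zip(x.tensor_split(10, dim=1), y.tensor_split(10, dim=1)):
        seq, hiddens = rnn(xt, hiddens)
        loss = F.mse_loss(head(seq), yt)

        opt.zero_grad()
        loss.backward()
        opt.step()

        # "Truncate": keep the state values, drop their autograd history.
        hiddens = tuple(h.detach() for h in hiddens)

Without the final detach, the second chunk's backward() would try to traverse the first chunk's already-freed graph and fail with a "backward through the graph a second time" error.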
@@ -219,3 +219,54 @@ class ParityModuleMNIST(LightningModule):
 
     def train_dataloader(self):
         return DataLoader(MNIST(root=_PATH_DATASETS, train=True, download=True), batch_size=128, num_workers=1)
+
+
+class TBPTTModule(LightningModule):
+    def __init__(self):
+        super().__init__()
+
+        self.batch_size = 10
+        self.in_features = 10
+        self.out_features = 5
+        self.hidden_dim = 20
+
+        self.automatic_optimization = False
+        self.truncated_bptt_steps = 10
+
+        self.rnn = nn.LSTM(self.in_features, self.hidden_dim, batch_first=True)
+        self.linear_out = nn.Linear(in_features=self.hidden_dim, out_features=self.out_features)
+
+    def forward(self, x, hs):
+        seq, hs = self.rnn(x, hs)
+        return self.linear_out(seq), hs
+
+    def training_step(self, batch, batch_idx):
+        x, y = batch
+        split_x, split_y = [
+            x.tensor_split(self.truncated_bptt_steps, dim=1),
+            y.tensor_split(self.truncated_bptt_steps, dim=1),
+        ]
+
+        hiddens = None
+        optimizer = self.optimizers()
+        losses = []
+
+        for x, y in zip(split_x, split_y):
+            y_pred, hiddens = self(x, hiddens)
+            loss = F.mse_loss(y_pred, y)
+
+            optimizer.zero_grad()
+            self.manual_backward(loss)
+            optimizer.step()
+
+            # "Truncate"
+            hiddens = [h.detach() for h in hiddens]
+            losses.append(loss.detach())
+
+        return
+
+    def configure_optimizers(self):
+        return torch.optim.Adam(self.parameters(), lr=0.001)
+
+    def train_dataloader(self):
+        return DataLoader(AverageDataset(), batch_size=self.batch_size)
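For reference, the splitting in training_step relies on torch.Tensor.tensor_split with an integer argument, which divides a dimension into that many chunks. A quick illustrative check, not part of the commit, of how the helper's defaults play out:

.. code-block:: python

    import torch

    # AverageDataset defaults to sequence_len=100 and the module uses
    # truncated_bptt_steps=10, so each batch is cut into 10 chunks of
    # 10 time steps along dim=1.
    x = torch.randn(10, 100, 10)  # (batch, time, features)
    chunks = x.tensor_split(10, dim=1)
    assert len(chunks) == 10
    assert all(c.shape == (10, 10, 10) for c in chunks)

If the time dimension were not evenly divisible, tensor_split would make the leading chunks one step longer rather than raising an error.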
@@ -17,7 +17,7 @@ import pytest
 from lightning.pytorch import Trainer
 from lightning.pytorch.demos.boring_classes import BoringModel
 
-from tests_pytorch.helpers.advanced_models import BasicGAN, ParityModuleMNIST, ParityModuleRNN
+from tests_pytorch.helpers.advanced_models import BasicGAN, ParityModuleMNIST, ParityModuleRNN, TBPTTModule
 from tests_pytorch.helpers.datamodules import ClassifDataModule, RegressDataModule
 from tests_pytorch.helpers.runif import RunIf
 from tests_pytorch.helpers.simple_models import ClassificationModel, RegressionModel
@@ -49,3 +49,10 @@ def test_models(tmp_path, data_class, model_class):
     model.to_torchscript()
     if data_class:
         model.to_onnx(os.path.join(tmp_path, "my-model.onnx"), input_sample=dm.sample)
+
+
+def test_tbptt(tmp_path):
+    model = TBPTTModule()
+
+    trainer = Trainer(default_root_dir=tmp_path, max_epochs=1)
+    trainer.fit(model)
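The new test_tbptt fits the helper end-to-end for a single epoch. Assuming the repository's usual test layout, it can presumably be run in isolation with ``pytest tests_pytorch/helpers/test_models.py -k test_tbptt``.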