diff --git a/docs/source-pytorch/common/tbptt.rst b/docs/source-pytorch/common/tbptt.rst
index 063ef8c33d..04b8ea33b9 100644
--- a/docs/source-pytorch/common/tbptt.rst
+++ b/docs/source-pytorch/common/tbptt.rst
@@ -12,48 +12,91 @@ hidden states should be kept in-between each time-dimension split.
 .. code-block:: python
 
     import torch
+    import torch.nn as nn
+    import torch.nn.functional as F
     import torch.optim as optim
-    import pytorch_lightning as pl
-    from pytorch_lightning import LightningModule
+    from torch.utils.data import Dataset, DataLoader
 
-    class LitModel(LightningModule):
+    import lightning as L
+
+
+    class AverageDataset(Dataset):
+        def __init__(self, dataset_len=300, sequence_len=100):
+            self.dataset_len = dataset_len
+            self.sequence_len = sequence_len
+            self.input_seq = torch.randn(dataset_len, sequence_len, 10)
+            top, bottom = self.input_seq.chunk(2, -1)
+            self.output_seq = top + bottom.roll(shifts=1, dims=-1)
+
+        def __len__(self):
+            return self.dataset_len
+
+        def __getitem__(self, item):
+            return self.input_seq[item], self.output_seq[item]
+
+
+    class LitModel(L.LightningModule):
         def __init__(self):
             super().__init__()
+            self.batch_size = 10
+            self.in_features = 10
+            self.out_features = 5
+            self.hidden_dim = 20
+
+            # 1. Switch to manual optimization
             self.automatic_optimization = False
             self.truncated_bptt_steps = 10
-            self.my_rnn = ParityModuleRNN()  # Define RNN model using ParityModuleRNN
+
+            self.rnn = nn.LSTM(self.in_features, self.hidden_dim, batch_first=True)
+            self.linear_out = nn.Linear(in_features=self.hidden_dim, out_features=self.out_features)
+
+        def forward(self, x, hs):
+            seq, hs = self.rnn(x, hs)
+            return self.linear_out(seq), hs
 
         # 2. Remove the `hiddens` argument
         def training_step(self, batch, batch_idx):
             # 3. Split the batch in chunks along the time dimension
-            split_batches = split_batch(batch, self.truncated_bptt_steps)
+            x, y = batch
+            split_x, split_y = [
+                x.tensor_split(self.truncated_bptt_steps, dim=1),
+                y.tensor_split(self.truncated_bptt_steps, dim=1),
+            ]
 
-            batch_size = 10
-            hidden_dim = 20
-            hiddens = torch.zeros(1, batch_size, hidden_dim, device=self.device)
-            for split_batch in range(split_batches):
-                # 4. Perform the optimization in a loop
-                loss, hiddens = self.my_rnn(split_batch, hiddens)
-                self.backward(loss)
-                self.optimizer.step()
-                self.optimizer.zero_grad()
+            hiddens = None
+            optimizer = self.optimizers()
+            losses = []
+
+            # 4. Perform the optimization in a loop
+            for x, y in zip(split_x, split_y):
+                y_pred, hiddens = self(x, hiddens)
+                loss = F.mse_loss(y_pred, y)
+
+                optimizer.zero_grad()
+                self.manual_backward(loss)
+                optimizer.step()
 
                 # 5. "Truncate"
-                hiddens = hiddens.detach()
+                hiddens = [h.detach() for h in hiddens]
+                losses.append(loss.detach())
+
+            avg_loss = sum(losses) / len(losses)
+            self.log("train_loss", avg_loss, prog_bar=True)
 
             # 6. Remove the return of `hiddens`
             # Returning loss in manual optimization is not needed
             return None
 
         def configure_optimizers(self):
-            return optim.Adam(self.my_rnn.parameters(), lr=0.001)
+            return optim.Adam(self.parameters(), lr=0.001)
+
+        def train_dataloader(self):
+            return DataLoader(AverageDataset(), batch_size=self.batch_size)
+
 
     if __name__ == "__main__":
         model = LitModel()
-        trainer = pl.Trainer(max_epochs=5)
-        trainer.fit(model, train_dataloader)  # Define your own dataloader
+        trainer = L.Trainer(max_epochs=5)
+        trainer.fit(model)
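Note (editor's aside, not part of the patch): in the updated example, `tensor_split` is given an integer, so `truncated_bptt_steps` acts as the number of chunks rather than the chunk length; with `sequence_len=100` and 10 splits, each chunk happens to be 10 steps long. A minimal sketch of the split and of why both LSTM hidden states are detached, assuming the same shapes as the example above:

    import torch
    import torch.nn as nn

    x = torch.randn(10, 100, 10)  # (batch, time, features)
    chunks = x.tensor_split(10, dim=1)  # 10 views, each of shape (10, 10, 10)
    assert len(chunks) == 10 and chunks[0].shape == (10, 10, 10)

    # nn.LSTM returns hs as the tuple (h_n, c_n); both tensors carry autograd
    # history, so the example detaches each of them between chunks to stop
    # gradients at the chunk boundary (the "truncation" in TBPTT).
    rnn = nn.LSTM(10, 20, batch_first=True)
    seq, (h, c) = rnn(chunks[0])
    h, c = h.detach(), c.detach()
    assert not h.requires_grad and not c.requires_grad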
diff --git a/tests/tests_pytorch/helpers/advanced_models.py b/tests/tests_pytorch/helpers/advanced_models.py
index 4fecf51601..ade21004dc 100644
--- a/tests/tests_pytorch/helpers/advanced_models.py
+++ b/tests/tests_pytorch/helpers/advanced_models.py
@@ -219,3 +219,54 @@ class ParityModuleMNIST(LightningModule):
 
     def train_dataloader(self):
         return DataLoader(MNIST(root=_PATH_DATASETS, train=True, download=True), batch_size=128, num_workers=1)
+
+
+class TBPTTModule(LightningModule):
+    def __init__(self):
+        super().__init__()
+
+        self.batch_size = 10
+        self.in_features = 10
+        self.out_features = 5
+        self.hidden_dim = 20
+
+        self.automatic_optimization = False
+        self.truncated_bptt_steps = 10
+
+        self.rnn = nn.LSTM(self.in_features, self.hidden_dim, batch_first=True)
+        self.linear_out = nn.Linear(in_features=self.hidden_dim, out_features=self.out_features)
+
+    def forward(self, x, hs):
+        seq, hs = self.rnn(x, hs)
+        return self.linear_out(seq), hs
+
+    def training_step(self, batch, batch_idx):
+        x, y = batch
+        split_x, split_y = [
+            x.tensor_split(self.truncated_bptt_steps, dim=1),
+            y.tensor_split(self.truncated_bptt_steps, dim=1),
+        ]
+
+        hiddens = None
+        optimizer = self.optimizers()
+        losses = []
+
+        for x, y in zip(split_x, split_y):
+            y_pred, hiddens = self(x, hiddens)
+            loss = F.mse_loss(y_pred, y)
+
+            optimizer.zero_grad()
+            self.manual_backward(loss)
+            optimizer.step()
+
+            # "Truncate"
+            hiddens = [h.detach() for h in hiddens]
+            losses.append(loss.detach())
+
+        return
+
+    def configure_optimizers(self):
+        return torch.optim.Adam(self.parameters(), lr=0.001)
+
+    def train_dataloader(self):
+        return DataLoader(AverageDataset(), batch_size=self.batch_size)
diff --git a/tests/tests_pytorch/helpers/test_models.py b/tests/tests_pytorch/helpers/test_models.py
index 7e44f79413..cca2fbdc2e 100644
--- a/tests/tests_pytorch/helpers/test_models.py
+++ b/tests/tests_pytorch/helpers/test_models.py
@@ -17,7 +17,7 @@ import pytest
 
 from lightning.pytorch import Trainer
 from lightning.pytorch.demos.boring_classes import BoringModel
-from tests_pytorch.helpers.advanced_models import BasicGAN, ParityModuleMNIST, ParityModuleRNN
+from tests_pytorch.helpers.advanced_models import BasicGAN, ParityModuleMNIST, ParityModuleRNN, TBPTTModule
 from tests_pytorch.helpers.datamodules import ClassifDataModule, RegressDataModule
 from tests_pytorch.helpers.runif import RunIf
 from tests_pytorch.helpers.simple_models import ClassificationModel, RegressionModel
@@ -49,3 +49,10 @@ def test_models(tmp_path, data_class, model_class):
     model.to_torchscript()
     if data_class:
         model.to_onnx(os.path.join(tmp_path, "my-model.onnx"), input_sample=dm.sample)
+
+
+def test_tbptt(tmp_path):
+    model = TBPTTModule()
+
+    trainer = Trainer(default_root_dir=tmp_path, max_epochs=1)
+    trainer.fit(model)
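Note (editor's aside, not part of the patch): a quick way to smoke-test the new helper outside pytest; `limit_train_batches` simply keeps the run short, and the import paths assume the test layout shown in the diff:

    from lightning.pytorch import Trainer
    from tests_pytorch.helpers.advanced_models import TBPTTModule

    # Mirrors test_tbptt above, trimmed to a couple of batches.
    model = TBPTTModule()
    trainer = Trainer(max_epochs=1, limit_train_batches=2, logger=False, enable_checkpointing=False)
    trainer.fit(model)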