lightning/examples/pytorch/basics/transformer.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

62 lines
1.9 KiB
Python
Raw Normal View History

ruff: replace isort with ruff +TPU (#17684) * ruff: replace isort with ruff * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fixing & imports * lines in warning test * docs * fix enum import * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fixing * import * fix lines * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * type ClusterEnvironment * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2023-09-26 15:54:55 +00:00
import lightning as L
2023-04-06 18:32:23 +00:00
import torch
import torch.nn.functional as F
from lightning.pytorch.demos import Transformer, WikiText2
ruff: replace isort with ruff +TPU (#17684) * ruff: replace isort with ruff * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fixing & imports * lines in warning test * docs * fix enum import * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fixing * import * fix lines * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * type ClusterEnvironment * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2023-09-26 15:54:55 +00:00
from torch.utils.data import DataLoader, random_split
2023-04-06 18:32:23 +00:00
class LanguageModel(L.LightningModule):
    """LightningModule that trains the demo ``Transformer`` with a negative log-likelihood loss.

    Args:
        vocab_size: Vocabulary size, forwarded to ``Transformer``.
        lr: Learning rate for the SGD optimizer (defaults to the original 0.1).
    """

    def __init__(self, vocab_size, lr=0.1):
        super().__init__()
        self.model = Transformer(vocab_size=vocab_size)
        self.lr = lr

    def _shared_step(self, batch, stage):
        """Run the model on one batch, log ``f"{stage}_loss"`` and return the loss.

        Args:
            batch: ``(inputs, target)`` pair as yielded by the dataloader.
            stage: Metric prefix: ``"train"``, ``"val"`` or ``"test"``.

        Returns:
            The scalar NLL loss tensor.
        """
        inputs, target = batch
        output = self.model(inputs, target)
        # F.nll_loss expects log-probabilities; this assumes the model returns them
        # already flattened to (num_tokens, vocab) so they line up with the 1-D
        # targets below -- TODO confirm against Transformer.forward.
        loss = F.nll_loss(output, target.view(-1))
        self.log(f"{stage}_loss", loss, prog_bar=True)
        return loss

    def training_step(self, batch, batch_idx):
        """Compute, log (as ``train_loss``) and return the training loss for one batch."""
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        """Compute, log (as ``val_loss``) and return the validation loss for one batch."""
        return self._shared_step(batch, "val")

    def test_step(self, batch, batch_idx):
        """Compute, log (as ``test_loss``) and return the test loss for one batch."""
        return self._shared_step(batch, "test")

    def configure_optimizers(self):
        """Return plain SGD over all module parameters with learning rate ``self.lr``."""
        return torch.optim.SGD(self.parameters(), lr=self.lr)
def main():
    """Train, validate and test ``LanguageModel`` on WikiText-2.

    Seeds RNGs with ``L.seed_everything(42)``, holds out fixed-size validation
    and test splits, fits for up to 20 epochs with gradient clipping at 0.25,
    then evaluates on the test split.

    Raises:
        ValueError: If the dataset is too small to leave any training samples
            after the validation and test hold-outs.
    """
    L.seed_everything(42)

    # Data
    dataset = WikiText2()

    # Split data into train, val, test; val and test have a fixed size.
    holdout_size = 2000
    n = len(dataset)
    n_train = n - 2 * holdout_size
    # Without this guard a small dataset gives a negative train length, which
    # random_split does not reject (the lengths still sum to n) and instead
    # turns into overlapping / wrong subsets.
    if n_train <= 0:
        raise ValueError(
            f"Dataset has {n} samples, but more than {2 * holdout_size} are needed "
            f"to hold out {holdout_size} validation and {holdout_size} test samples."
        )
    train_dataset, val_dataset, test_dataset = random_split(dataset, [n_train, holdout_size, holdout_size])

    batch_size = 20
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # Model
    model = LanguageModel(vocab_size=dataset.vocab_size)

    # Trainer
    trainer = L.Trainer(gradient_clip_val=0.25, max_epochs=20)
    trainer.fit(model, train_dataloader, val_dataloader)
    trainer.test(model, test_dataloader)
# Run the example only when executed as a script, not on import.
if __name__ == "__main__":
    main()