lightning/examples/fabric/language_model/train.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

76 lines
2.5 KiB
Python
Raw Normal View History

ruff: replace isort with ruff +TPU (#17684) * ruff: replace isort with ruff * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fixing & imports * lines in warning test * docs * fix enum import * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fixing * import * fix lines * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * type ClusterEnvironment * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2023-09-26 15:54:55 +00:00
import lightning as L
2023-04-06 18:32:23 +00:00
import torch
import torch.nn.functional as F
from lightning.pytorch.demos import Transformer, WikiText2
ruff: replace isort with ruff +TPU (#17684) * ruff: replace isort with ruff * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fixing & imports * lines in warning test * docs * fix enum import * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fixing * import * fix lines * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * type ClusterEnvironment * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2023-09-26 15:54:55 +00:00
from torch.utils.data import DataLoader, random_split
2023-04-06 18:32:23 +00:00
def main():
    """Entry point: build the WikiText2 data, the Transformer model, and run training via Fabric."""
    L.seed_everything(42)
    fabric = L.Fabric()

    # Data: download/prepare WikiText2 and split it into loaders.
    dataset = WikiText2()
    train_loader, val_loader, _ = get_dataloaders(dataset)

    # Model sized to the dataset's vocabulary.
    model = Transformer(vocab_size=dataset.vocab_size)

    # Plain SGD; Fabric wires model/optimizer/dataloaders to the chosen device & strategy.
    sgd = torch.optim.SGD(model.parameters(), lr=0.1)
    model, sgd = fabric.setup(model, sgd)
    train_loader, val_loader = fabric.setup_dataloaders(train_loader, val_loader)

    train(fabric, model, sgd, train_loader, val_loader)
def train(fabric, model, optimizer, train_dataloader, val_dataloader, max_epochs=20):
    """Run ``max_epochs`` rounds of one training epoch followed by one validation pass."""
    epoch = 0
    while epoch < max_epochs:
        train_epoch(fabric, model, optimizer, train_dataloader, epoch)
        loss = validate(fabric, model, val_dataloader)
        fabric.print(f"val loss {loss.item():.4f}")
        epoch += 1
def train_epoch(fabric, model, optimizer, train_dataloader, epoch):
    """Run one optimization pass over ``train_dataloader``, logging the loss every 200 steps."""
    for step, (data, labels) in enumerate(train_dataloader):
        logits = model(data, labels)
        loss = F.nll_loss(logits, labels.view(-1))

        # Fabric handles backward (e.g. scaling under mixed precision) and gradient clipping.
        fabric.backward(loss)
        fabric.clip_gradients(model, optimizer, clip_val=0.25)

        optimizer.step()
        optimizer.zero_grad()

        if step % 200 == 0:
            fabric.print(f"epoch: {epoch} - iteration: {step} - loss {loss.item():.4f}")
@torch.no_grad()
def validate(fabric, model, val_dataloader):
    """Compute the mean validation loss over ``val_dataloader`` and return it as a scalar tensor."""
    fabric.print("Validating ...")
    model.eval()

    batch_losses = []
    for data, labels in val_dataloader:
        logits = model(data, labels)
        batch_losses.append(F.nll_loss(logits, labels.view(-1)).item())

    # Restore training mode before handing control back to the training loop.
    model.train()
    return torch.tensor(batch_losses).mean()
def get_dataloaders(dataset):
    """Split ``dataset`` into train / 2000-item val / 2000-item test loaders (batch size 20).

    The split uses a fixed-seed generator so it is reproducible across runs.
    """
    total = len(dataset)
    gen = torch.Generator().manual_seed(42)
    train_ds, val_ds, test_ds = random_split(dataset, [total - 4000, 2000, 2000], generator=gen)

    # Only the training loader is shuffled; eval loaders keep a stable order.
    return (
        DataLoader(train_ds, batch_size=20, shuffle=True),
        DataLoader(val_ds, batch_size=20, shuffle=False),
        DataLoader(test_ds, batch_size=20, shuffle=False),
    )
# Run the example only when executed as a script (not when imported as a module).
if __name__ == "__main__":
    main()