lightning/examples/fabric/fp8_distributed_transformer/train.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

101 lines
3.0 KiB
Python
Raw Normal View History

import lightning as L
import torch
import torch.nn as nn
import torch.nn.functional as F
from lightning.fabric.strategies import ModelParallelStrategy
from lightning.pytorch.demos import Transformer, WikiText2
from torch.distributed._composable.fsdp.fully_shard import fully_shard
from torch.distributed.device_mesh import DeviceMesh
from torch.utils.data import DataLoader
from torchao.float8 import Float8LinearConfig, convert_to_float8_training
from tqdm import tqdm
def configure_model(model: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
    """Convert ``model`` to float8 training, shard it with FSDP2 and compile it.

    This is passed as the ``parallelize_fn`` of ``ModelParallelStrategy`` in ``train``.

    Args:
        model: The model to parallelize (in ``train`` it is still on the meta device).
        device_mesh: Device mesh supplied by the strategy. If it has a dimension named
            ``"data_parallel"``, only that sub-mesh is used for FSDP sharding; otherwise
            the mesh is used as-is.

    Returns:
        The ``torch.compile``-wrapped, FSDP2-sharded model.
    """
    # NOTE: float8 under torch.compile may need a recent Triton; this example originally
    # suggested installing the nightly build with:
    # pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly  # noqa
    float8_config = Float8LinearConfig(
        # Presumably pads matmul inner dims to a float8-compatible multiple; see torchao docs.
        pad_inner_dim=True,
    )

    def module_filter_fn(mod: torch.nn.Module, fqn: str) -> bool:
        # Skip the output ``decoder`` projection: its size is the vocabulary size,
        # which typically is not divisible by 16 as float8 matmuls require.
        return fqn != "decoder"

    convert_to_float8_training(model, config=float8_config, module_filter_fn=module_filter_fn)

    # Lightning's ModelParallelStrategy builds a 2D ("data_parallel", "tensor_parallel")
    # mesh. FSDP2's ``fully_shard`` treats a 2D mesh as HSDP: it replicates over dim 0
    # and shards over dim 1. Passing the full mesh would therefore shard only over the
    # (size-1) tensor-parallel dim and replicate across data-parallel ranks. Shard over
    # the data-parallel sub-mesh instead so parameters are actually distributed.
    mesh_dim_names = device_mesh.mesh_dim_names or ()
    dp_mesh = device_mesh["data_parallel"] if "data_parallel" in mesh_dim_names else device_mesh

    # Shard each transformer layer as its own FSDP unit, then the root module, so
    # per-layer parameters can be gathered/freed independently.
    for module in model.modules():
        if isinstance(module, (torch.nn.TransformerEncoderLayer, torch.nn.TransformerDecoderLayer)):
            fully_shard(module, mesh=dp_mesh)

    fully_shard(model, mesh=dp_mesh)

    return torch.compile(model)
def train():
    """Train the demo Transformer on WikiText2 with FSDP2 and float8 linear layers.

    Uses gradient accumulation: each optimizer step consumes
    ``batch_size // micro_batch_size`` micro-batches. Training stops after
    ``max_steps`` optimizer steps or at the end of one pass over the dataloader,
    whichever comes first. Prints a CUDA memory summary at the end.
    """
    L.seed_everything(42)

    batch_size = 8
    micro_batch_size = 1
    max_steps = 100
    # Number of micro-batches whose gradients are accumulated per optimizer step.
    grad_accum_steps = batch_size // micro_batch_size

    dataset = WikiText2()
    dataloader = DataLoader(dataset, num_workers=8, batch_size=micro_batch_size)

    # Build the model on the meta device so no weights are allocated here; the
    # strategy is expected to materialize them after sharding in ``fabric.setup``.
    with torch.device("meta"):
        model = Transformer(
            vocab_size=dataset.vocab_size,
            nlayers=16,
            nhid=4096,
            ninp=1024,
            nhead=32,
        )

    strategy = ModelParallelStrategy(data_parallel_size=4, tensor_parallel_size=1, parallelize_fn=configure_model)
    fabric = L.Fabric(precision="bf16-true", strategy=strategy)
    fabric.launch()

    model = fabric.setup(model)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    optimizer = fabric.setup_optimizers(optimizer)
    dataloader = fabric.setup_dataloaders(dataloader)

    # The loop stops after ``max_steps`` optimizer steps, so cap the progress-bar total.
    total = min(len(dataloader), max_steps * grad_accum_steps)
    iterable = tqdm(enumerate(dataloader), total=total) if fabric.is_global_zero else enumerate(dataloader)

    steps = 0
    for i, batch in iterable:
        inputs, target = batch
        # Step on every ``grad_accum_steps``-th micro-batch (1-based). This way every
        # optimizer step, including the first, accumulates a full batch.
        is_accumulating = (i + 1) % grad_accum_steps != 0

        # Skip gradient synchronization on accumulation-only micro-batches.
        with fabric.no_backward_sync(model, enabled=is_accumulating):
            output = model(inputs, target)
            loss = F.nll_loss(output, target.view(-1))
            # Average over micro-batches so the accumulated gradient matches a
            # single ``batch_size`` step instead of their sum.
            fabric.backward(loss / grad_accum_steps)

        if not is_accumulating:
            fabric.clip_gradients(model, optimizer, max_norm=1.0)
            optimizer.step()
            optimizer.zero_grad()
            steps += 1

            if fabric.is_global_zero:
                # Shows the (unscaled) loss of the last micro-batch.
                iterable.set_postfix_str(f"train_loss={loss.item():.2f}")

            if steps == max_steps:
                break

    if fabric.is_global_zero:
        iterable.close()

    fabric.print(torch.cuda.memory_summary())
if __name__ == "__main__":
    # Script entry point. Before training, allow float32 matmuls to use faster,
    # reduced-precision internal computation (e.g. TF32) where the hardware supports it.
    torch.set_float32_matmul_precision("high")
    train()