import lightning as L
import torch
import torch.nn.functional as F
from data import RandomTokenDataset
from lightning.fabric.strategies import ModelParallelStrategy
from model import ModelArgs, Transformer
from parallelism import parallelize
from torch.distributed.tensor.parallel import loss_parallel
from torch.utils.data import DataLoader


def train():
    strategy = ModelParallelStrategy(
        # User-defined function that applies the desired parallelizations specific to the model
        # (TP, FSDP2, activation checkpointing, ...)
        parallelize_fn=parallelize,
        # Define the size of the 2D parallelism
        # Set to "auto" to apply TP intra-node and DP inter-node
        data_parallel_size=2,
        tensor_parallel_size=2,
    )
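
    # The 2 x 2 mesh above (data-parallel x tensor-parallel) corresponds to the
    # 4 devices requested from Fabric below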
    fabric = L.Fabric(accelerator="cuda", devices=4, strategy=strategy)
    fabric.launch()

    # Initialize the model with empty (deferred) weights; the real weights get
    # initialized after sharding, via `model.init_weights()` below
    model_args = ModelArgs(vocab_size=32000)
    with fabric.init_module(empty_init=True):
        model = Transformer(model_args)

    fabric.print(f"Number of model parameters: {sum(p.numel() for p in model.parameters()) / 1e9:.1f} B")

    # Set up the model: the strategy applies `parallelize_fn` and shards the parameters,
    # which is why the weights are initialized only after setup
    model = fabric.setup(model)
    model.init_weights()

    # Define the optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-3, foreach=True)
    optimizer = fabric.setup_optimizers(optimizer)

    # Define dataset/dataloader
    dataset = RandomTokenDataset(vocab_size=model_args.vocab_size, seq_length=128)
    dataloader = DataLoader(dataset, batch_size=8)

    # Fabric configures the sampler automatically for you such that
    # all batches in a tensor-parallel group are identical
    dataloader = fabric.setup_dataloaders(dataloader)

    # Simplified training loop
    fabric.print("Starting training ...")

    for i, batch in enumerate(dataloader):
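        # Next-token prediction: the labels are the input sequence shifted by one position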
        inputs = batch[:, :-1]
        labels = batch[:, 1:]

        output = model(inputs)
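
        # loss_parallel() computes the cross-entropy directly on the sharded logits
        # (no gather over the vocabulary dimension); the backward pass must also run
        # inside this context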
        with loss_parallel():
            loss = F.cross_entropy(output.reshape(-1, output.size(-1)), labels.reshape(-1))
            fabric.backward(loss)

        optimizer.step()
        optimizer.zero_grad()
        fabric.print(f"Iteration {i} complete")

    # See `fabric consolidate --help` if you need to convert the checkpoint to a single file
    fabric.print("Saving a (distributed) checkpoint ...")
    state = {"model": model, "optimizer": optimizer, "iteration": i}
    fabric.save("checkpoint.pt", state)

    fabric.print("Training successfully completed!")
    fabric.print(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")


if __name__ == "__main__":
    assert torch.cuda.device_count() >= 4, "This example requires at least 4 GPUs with 24 GB of memory each."
    # "high" allows TF32-based float32 matmuls on GPUs that support them
    torch.set_float32_matmul_precision("high")
    train()