import lightning as L
import torch
import torch.nn.functional as F
from data import RandomTokenDataset
from lightning.pytorch.strategies import ModelParallelStrategy
from model import ModelArgs, Transformer
from parallelism import parallelize
from torch.distributed.tensor.parallel import loss_parallel
from torch.utils.data import DataLoader


class Llama3(L.LightningModule):
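    # Wraps a Llama 3-style Transformer in a LightningModule; the tensor-parallel + FSDP2
    # plan defined in parallelism.py is applied to it in configure_model().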
    def __init__(self):
        super().__init__()
        self.model_args = ModelArgs(vocab_size=32000)
        self.model = Transformer(self.model_args)

    def configure_model(self):
        # User-defined function that applies the desired parallelizations specific to the model
        # (TP, FSDP2, activation checkpointing, ...)
        parallelize(self.model, device_mesh=self.device_mesh)

    def on_train_start(self) -> None:
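        # The weights were created uninitialized (empty_init=True) and the model has been
        # parallelized in configure_model() by now, so initialize their values here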
        self.model.init_weights()

    def training_step(self, batch):
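        # Next-token prediction: drop the last token for the inputs, shift the labels by one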
        inputs = batch[:, :-1]
        labels = batch[:, 1:]
        output = self.model(inputs)
        # Optional: Parallelize loss computation across the class dimension (see parallelism.py)
        with loss_parallel():
            return F.cross_entropy(output.reshape(-1, output.size(-1)), labels.reshape(-1))

    def backward(self, *args, **kwargs):
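        # The loss was computed under loss_parallel(), so its backward pass has to run
        # inside the same context manager as well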
        with loss_parallel():
            super().backward(*args, **kwargs)

    def configure_optimizers(self):
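        # foreach=True picks the multi-tensor (batched) implementation of AdamW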
        return torch.optim.AdamW(self.model.parameters(), lr=3e-3, foreach=True)

    def train_dataloader(self):
        dataset = RandomTokenDataset(vocab_size=self.model_args.vocab_size, seq_length=128)
        # The Trainer configures the sampler automatically for you such that
        # all batches in a tensor-parallel group are identical
        return DataLoader(dataset, batch_size=8, num_workers=4)


def train():
    strategy = ModelParallelStrategy(
        # Define the size of the 2D parallelism
        # Set to "auto" to apply TP intra-node and DP inter-node
        data_parallel_size=2,
        tensor_parallel_size=2,
    )
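    # 2 (data-parallel) x 2 (tensor-parallel) = 4 GPUs, matching devices=4 below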

    trainer = L.Trainer(
        accelerator="cuda",
        devices=4,
        strategy=strategy,
        limit_train_batches=10,
        max_epochs=1,
    )

    # Initialize the model
    with trainer.init_module(empty_init=True):
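        # With empty_init=True the full weights are not materialized here; the model is
        # sharded in configure_model() and its values are set later in on_train_start()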
        model = Llama3()

    trainer.print(f"Number of model parameters: {sum(p.numel() for p in model.parameters()) / 1e9:.1f} B")
    trainer.print("Starting training ...")

    trainer.fit(model)

    trainer.print("Training successfully completed!")
    trainer.print(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")


if __name__ == "__main__":
    assert torch.cuda.device_count() >= 4, "This example requires at least 4 GPUs with 24 GB of memory each."
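    # Trade a little float32 matmul precision (e.g. TF32 tensor cores) for speed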
    torch.set_float32_matmul_precision("high")
    train()