From d76feef0d6c4f46c1e01be49c13c965f4ce942ef Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Mon, 20 May 2024 13:19:38 +0200
Subject: [PATCH] Enable loss-parallel in example (#19882)

---
 examples/fabric/tensor_parallel/parallelism.py  | 8 +++++++-
 examples/fabric/tensor_parallel/train.py        | 2 +-
 examples/pytorch/tensor_parallel/parallelism.py | 8 +++++++-
 examples/pytorch/tensor_parallel/train.py       | 5 +++++
 4 files changed, 20 insertions(+), 3 deletions(-)

diff --git a/examples/fabric/tensor_parallel/parallelism.py b/examples/fabric/tensor_parallel/parallelism.py
index f6f38aa499..44d55c8da1 100644
--- a/examples/fabric/tensor_parallel/parallelism.py
+++ b/examples/fabric/tensor_parallel/parallelism.py
@@ -35,7 +35,13 @@ def parallelize(model: Transformer, device_mesh: DeviceMesh) -> Transformer:
     # Parallelize the first embedding and the last linear out projection
     plan = {
         "tok_embeddings": RowwiseParallel(input_layouts=Replicate()),
-        "output": ColwiseParallel(input_layouts=Shard(1), output_layouts=Replicate()),
+        "output": ColwiseParallel(
+            input_layouts=Shard(1),
+            # Optional: Shard the output along the class dimension to compute the loss in parallel.
+            # See `loss_parallel` in `train.py`
+            output_layouts=Shard(-1),
+            use_local_output=False,
+        ),
         "norm": SequenceParallel(),
         "layers.0": PrepareModuleInput(
             input_layouts=(Replicate(), None),
diff --git a/examples/fabric/tensor_parallel/train.py b/examples/fabric/tensor_parallel/train.py
index 2c3ab38198..ce48fe341f 100644
--- a/examples/fabric/tensor_parallel/train.py
+++ b/examples/fabric/tensor_parallel/train.py
@@ -57,8 +57,8 @@ def train():
 
         with loss_parallel():
             loss = F.cross_entropy(output.reshape(-1, output.size(-1)), labels.reshape(-1))
+            fabric.backward(loss)
 
-        fabric.backward(loss)
         optimizer.step()
         optimizer.zero_grad()
         fabric.print(f"Iteration {i} complete")
diff --git a/examples/pytorch/tensor_parallel/parallelism.py b/examples/pytorch/tensor_parallel/parallelism.py
index f6f38aa499..44d55c8da1 100644
--- a/examples/pytorch/tensor_parallel/parallelism.py
+++ b/examples/pytorch/tensor_parallel/parallelism.py
@@ -35,7 +35,13 @@ def parallelize(model: Transformer, device_mesh: DeviceMesh) -> Transformer:
     # Parallelize the first embedding and the last linear out projection
     plan = {
         "tok_embeddings": RowwiseParallel(input_layouts=Replicate()),
-        "output": ColwiseParallel(input_layouts=Shard(1), output_layouts=Replicate()),
+        "output": ColwiseParallel(
+            input_layouts=Shard(1),
+            # Optional: Shard the output along the class dimension to compute the loss in parallel.
+            # See `loss_parallel` in `train.py`
+            output_layouts=Shard(-1),
+            use_local_output=False,
+        ),
         "norm": SequenceParallel(),
         "layers.0": PrepareModuleInput(
             input_layouts=(Replicate(), None),
diff --git a/examples/pytorch/tensor_parallel/train.py b/examples/pytorch/tensor_parallel/train.py
index ad4220a3fc..6efbadf175 100644
--- a/examples/pytorch/tensor_parallel/train.py
+++ b/examples/pytorch/tensor_parallel/train.py
@@ -27,9 +27,14 @@ class Llama2(L.LightningModule):
         inputs = batch[:, :-1]
         labels = batch[:, 1:]
         output = self.model(inputs)
+        # Optional: Parallelize loss computation across class dimension (see parallelism.py)
         with loss_parallel():
             return F.cross_entropy(output.reshape(-1, output.size(-1)), labels.reshape(-1))
 
+    def backward(self, *args, **kwargs):
+        with loss_parallel():
+            super().backward(*args, **kwargs)
+
     def configure_optimizers(self):
         return torch.optim.AdamW(self.model.parameters(), lr=3e-3, foreach=True)
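
For context, the snippet below is a minimal, standalone sketch of the loss-parallel pattern this patch enables: the output projection is sharded column-wise with `output_layouts=Shard(-1)` and `use_local_output=False`, so the logits stay a DTensor sharded along the class dimension, and both the cross-entropy and its backward run inside `loss_parallel()`. It is not part of the patch and makes several assumptions: a hypothetical `ToyLM` module stands in for the examples' Transformer, it uses raw PyTorch (no Fabric or Trainer), it assumes PyTorch 2.3-era import paths, and it expects to be launched with `torchrun --nproc_per_node=<N>` on N GPUs.

import os

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed._tensor import Shard
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import ColwiseParallel, loss_parallel, parallelize_module


class ToyLM(nn.Module):
    """Hypothetical stand-in for the examples' Transformer: embedding + output projection."""

    def __init__(self, vocab_size: int = 128, dim: int = 64) -> None:
        super().__init__()
        self.tok_embeddings = nn.Embedding(vocab_size, dim)
        self.output = nn.Linear(dim, vocab_size, bias=False)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        return self.output(self.tok_embeddings(tokens))


def main() -> None:
    # Assumes a `torchrun --nproc_per_node=<N>` launch; each process drives one GPU.
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    torch.cuda.set_device(local_rank)

    # 1D device mesh spanning all ranks (sets up the default process group if needed).
    mesh = init_device_mesh("cuda", (world_size,))

    torch.manual_seed(0)  # identical init on every rank, since only `output` is sharded here
    model = ToyLM().to("cuda")

    # As in the patch: shard the output projection column-wise and keep its result a
    # DTensor sharded along the class (vocab) dimension instead of replicating it.
    model = parallelize_module(
        model,
        mesh,
        {"output": ColwiseParallel(output_layouts=Shard(-1), use_local_output=False)},
    )

    tokens = torch.randint(0, 128, (8, 16), device="cuda")
    labels = torch.randint(0, 128, (8, 16), device="cuda")
    logits = model(tokens)  # DTensor, sharded on the last (class) dimension

    # Both the loss and its backward must run inside loss_parallel(); that is why the
    # patch moves `fabric.backward(loss)` into the context and overrides `backward()`.
    with loss_parallel():
        loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))
        loss.backward()

    dist.destroy_process_group()


if __name__ == "__main__":
    main()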