Enable loss-parallel in example (#19882)

This commit is contained in:
awaelchli 2024-05-20 13:19:38 +02:00 committed by GitHub
parent 82e6e61bea
commit d76feef0d6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 20 additions and 3 deletions

View File

@ -35,7 +35,13 @@ def parallelize(model: Transformer, device_mesh: DeviceMesh) -> Transformer:
# Parallelize the first embedding and the last linear out projection
plan = {
"tok_embeddings": RowwiseParallel(input_layouts=Replicate()),
"output": ColwiseParallel(input_layouts=Shard(1), output_layouts=Replicate()),
"output": ColwiseParallel(
input_layouts=Shard(1),
# Optional: Shard the output along the class dimension to compute the loss in parallel.
# See `loss_parallel` in `train.py`
output_layouts=Shard(-1),
use_local_output=False,
),
"norm": SequenceParallel(),
"layers.0": PrepareModuleInput(
input_layouts=(Replicate(), None),

View File

@ -57,8 +57,8 @@ def train():
with loss_parallel():
loss = F.cross_entropy(output.reshape(-1, output.size(-1)), labels.reshape(-1))
fabric.backward(loss)
fabric.backward(loss)
optimizer.step()
optimizer.zero_grad()
fabric.print(f"Iteration {i} complete")

View File

@ -35,7 +35,13 @@ def parallelize(model: Transformer, device_mesh: DeviceMesh) -> Transformer:
# Parallelize the first embedding and the last linear out projection
plan = {
"tok_embeddings": RowwiseParallel(input_layouts=Replicate()),
"output": ColwiseParallel(input_layouts=Shard(1), output_layouts=Replicate()),
"output": ColwiseParallel(
input_layouts=Shard(1),
# Optional: Shard the output along the class dimension to compute the loss in parallel.
# See `loss_parallel` in `train.py`
output_layouts=Shard(-1),
use_local_output=False,
),
"norm": SequenceParallel(),
"layers.0": PrepareModuleInput(
input_layouts=(Replicate(), None),

View File

@ -27,9 +27,14 @@ class Llama2(L.LightningModule):
inputs = batch[:, :-1]
labels = batch[:, 1:]
output = self.model(inputs)
# Optional: Parallelize loss computation across class dimension (see parallelism.py)
with loss_parallel():
return F.cross_entropy(output.reshape(-1, output.size(-1)), labels.reshape(-1))
def backward(self, *args, **kwargs):
with loss_parallel():
super().backward(*args, **kwargs)
def configure_optimizers(self):
return torch.optim.AdamW(self.model.parameters(), lr=3e-3, foreach=True)