Enable loss-parallel in example (#19882)
parent 82e6e61bea
commit d76feef0d6
@@ -35,7 +35,13 @@ def parallelize(model: Transformer, device_mesh: DeviceMesh) -> Transformer:
     # Parallelize the first embedding and the last linear out projection
     plan = {
         "tok_embeddings": RowwiseParallel(input_layouts=Replicate()),
-        "output": ColwiseParallel(input_layouts=Shard(1), output_layouts=Replicate()),
+        "output": ColwiseParallel(
+            input_layouts=Shard(1),
+            # Optional: Shard the output along the class dimension to compute the loss in parallel.
+            # See `loss_parallel` in `train.py`
+            output_layouts=Shard(-1),
+            use_local_output=False,
+        ),
         "norm": SequenceParallel(),
         "layers.0": PrepareModuleInput(
             input_layouts=(Replicate(), None),

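For reference, a minimal sketch (not part of this commit) of how a plan like the one above is applied with `parallelize_module`. `TinyLM` is a hypothetical stand-in for the example's `Transformer`, and the 2-GPU mesh size is an assumption; launch with e.g. `torchrun --nproc_per_node=2`.

# Sketch: applying a tensor-parallel plan like the one above (hypothetical TinyLM model,
# assumed 2-GPU mesh; run with `torchrun --nproc_per_node=2`).
import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed._tensor import Replicate, Shard  # `torch.distributed.tensor` in newer torch releases
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module


class TinyLM(nn.Module):
    def __init__(self, vocab: int = 32000, dim: int = 256):
        super().__init__()
        self.tok_embeddings = nn.Embedding(vocab, dim)
        self.output = nn.Linear(dim, vocab, bias=False)

    def forward(self, tokens):
        return self.output(self.tok_embeddings(tokens))


tp_mesh = init_device_mesh("cuda", (2,))  # 1-D tensor-parallel mesh over 2 GPUs
model = parallelize_module(
    TinyLM().to("cuda"),
    tp_mesh,
    {
        # Shard the embedding table row-wise and hand the activations on sequence-sharded
        "tok_embeddings": RowwiseParallel(input_layouts=Replicate(), output_layouts=Shard(1)),
        "output": ColwiseParallel(
            input_layouts=Shard(1),    # sequence-sharded activations coming in
            output_layouts=Shard(-1),  # keep the logits sharded over the class/vocab dim
            use_local_output=False,    # return a DTensor so `loss_parallel()` can consume it
        ),
    },
)

tokens = torch.randint(0, 32000, (4, 8), device="cuda")
logits = model(tokens)  # DTensor of global shape (4, 8, 32000), sharded on the last dim
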
@@ -57,8 +57,8 @@ def train():
-        loss = F.cross_entropy(output.reshape(-1, output.size(-1)), labels.reshape(-1))
-        fabric.backward(loss)
+        with loss_parallel():
+            loss = F.cross_entropy(output.reshape(-1, output.size(-1)), labels.reshape(-1))
+            fabric.backward(loss)
 
         optimizer.step()
         optimizer.zero_grad()
         fabric.print(f"Iteration {i} complete")

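The sketch below (not from the commit) isolates the pattern the loop above relies on: the logits are a DTensor sharded along the class dimension, and both the cross-entropy and the backward pass run inside `loss_parallel()`; with Fabric, `fabric.backward(loss)` takes the place of `loss.backward()`. Shapes, the 2-GPU mesh, and the import path for `Shard`/`distribute_tensor` are assumptions.

# Sketch: loss-parallel cross-entropy over class-sharded logits (assumed 2-GPU mesh;
# run with `torchrun --nproc_per_node=2`).
import torch
import torch.nn.functional as F
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed._tensor import Shard, distribute_tensor  # `torch.distributed.tensor` in newer torch
from torch.distributed.tensor.parallel import loss_parallel

mesh = init_device_mesh("cuda", (2,))

# Stand-in for the model output: logits of global shape (batch, seq, vocab), sharded
# over the vocab/class dimension, as `output_layouts=Shard(-1)` + `use_local_output=False` produce.
logits = distribute_tensor(torch.randn(4, 8, 128, device="cuda", requires_grad=True), mesh, [Shard(2)])
labels = torch.randint(0, 128, (4, 8), device="cuda")

with loss_parallel():
    # Both the forward and the backward of the sharded cross-entropy must run in the
    # context, which is why `fabric.backward(loss)` sits inside the `with` block above.
    loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))
    loss.backward()
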
@@ -35,7 +35,13 @@ def parallelize(model: Transformer, device_mesh: DeviceMesh) -> Transformer:
     # Parallelize the first embedding and the last linear out projection
     plan = {
         "tok_embeddings": RowwiseParallel(input_layouts=Replicate()),
-        "output": ColwiseParallel(input_layouts=Shard(1), output_layouts=Replicate()),
+        "output": ColwiseParallel(
+            input_layouts=Shard(1),
+            # Optional: Shard the output along the class dimension to compute the loss in parallel.
+            # See `loss_parallel` in `train.py`
+            output_layouts=Shard(-1),
+            use_local_output=False,
+        ),
         "norm": SequenceParallel(),
         "layers.0": PrepareModuleInput(
             input_layouts=(Replicate(), None),

@@ -27,9 +27,14 @@ class Llama2(L.LightningModule):
         inputs = batch[:, :-1]
         labels = batch[:, 1:]
         output = self.model(inputs)
-        return F.cross_entropy(output.reshape(-1, output.size(-1)), labels.reshape(-1))
+        # Optional: Parallelize loss computation across class dimension (see parallelism.py)
+        with loss_parallel():
+            return F.cross_entropy(output.reshape(-1, output.size(-1)), labels.reshape(-1))
+
+    def backward(self, *args, **kwargs):
+        with loss_parallel():
+            super().backward(*args, **kwargs)
 
     def configure_optimizers(self):
         return torch.optim.AdamW(self.model.parameters(), lr=3e-3, foreach=True)

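A torch-only sketch (not from the commit) of why the `backward` override above exists: Lightning invokes the actual backward call outside of `training_step`, so `loss_parallel()` has to be re-entered around it. Entering the context once for the loss computation and again for the backward is what the two hooks amount to; the shapes and the 2-GPU mesh here are assumptions.

# Sketch: forward and backward of the parallel loss in two separate `loss_parallel()` blocks,
# mirroring `training_step` and the `backward` override above (run with `torchrun --nproc_per_node=2`).
import torch
import torch.nn.functional as F
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed._tensor import Shard, distribute_tensor  # `torch.distributed.tensor` in newer torch
from torch.distributed.tensor.parallel import loss_parallel

mesh = init_device_mesh("cuda", (2,))
logits = distribute_tensor(torch.randn(16, 128, device="cuda", requires_grad=True), mesh, [Shard(1)])
labels = torch.randint(0, 128, (16,), device="cuda")

# what `training_step` does: compute the loss under the context
with loss_parallel():
    loss = F.cross_entropy(logits, labels)

# ...Lightning's training loop would run in between...

# what the `backward` override does: wrap the framework's backward call in the context again
with loss_parallel():
    loss.backward()
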