Fix callback call in Fabric Trainer example (#19986)

This commit is contained in:
liambsmith 2024-06-18 13:14:32 -04:00 committed by GitHub
parent c1af4d0527
commit 394c42aaf6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 1 additions and 1 deletions

View File

@ -227,7 +227,7 @@ class MyCustomTrainer:
should_optim_step = self.global_step % self.grad_accum_steps == 0 should_optim_step = self.global_step % self.grad_accum_steps == 0
if should_optim_step: if should_optim_step:
# currently only supports a single optimizer # currently only supports a single optimizer
self.fabric.call("on_before_optimizer_step", optimizer, 0) self.fabric.call("on_before_optimizer_step", optimizer)
# optimizer step runs train step internally through closure # optimizer step runs train step internally through closure
optimizer.step(partial(self.training_step, model=model, batch=batch, batch_idx=batch_idx)) optimizer.step(partial(self.training_step, model=model, batch=batch, batch_idx=batch_idx))