Fix callback call in Fabric Trainer example (#19986)
This commit is contained in:
parent
c1af4d0527
commit
394c42aaf6
|
@ -227,7 +227,7 @@ class MyCustomTrainer:
|
||||||
should_optim_step = self.global_step % self.grad_accum_steps == 0
|
should_optim_step = self.global_step % self.grad_accum_steps == 0
|
||||||
if should_optim_step:
|
if should_optim_step:
|
||||||
# currently only supports a single optimizer
|
# currently only supports a single optimizer
|
||||||
self.fabric.call("on_before_optimizer_step", optimizer, 0)
|
self.fabric.call("on_before_optimizer_step", optimizer)
|
||||||
|
|
||||||
# optimizer step runs train step internally through closure
|
# optimizer step runs train step internally through closure
|
||||||
optimizer.step(partial(self.training_step, model=model, batch=batch, batch_idx=batch_idx))
|
optimizer.step(partial(self.training_step, model=model, batch=batch, batch_idx=batch_idx))
|
||||||
|
|
Loading…
Reference in New Issue