ThroughputMonitor Trainer callback fixes (#19027)

Carlos Mocholí 2023-11-21 02:40:13 -05:00 committed by GitHub
parent dd206a20d8
commit 6cbe9ceb56
1 changed file with 4 additions and 4 deletions


@@ -193,14 +193,14 @@ class ThroughputMonitor(Callback):
         self._update(trainer, pl_module, batch, iter_num)
         self._compute(trainer, iter_num)
 
     @rank_zero_only
     def on_validation_end(self, trainer: "Trainer", *_: Any) -> None:
         if trainer.sanity_checking or trainer.state.fn != TrainerFn.FITTING:
             return
         # add the validation time to the training time before continuing to avoid sinking the training throughput
-        time_between_train_and_val = (
-            self._t0s[RunningStage.VALIDATING] - self._throughputs[RunningStage.TRAINING]._time[-1]
-        )
-        val_time = self._throughputs[RunningStage.VALIDATING]._time[-1]
+        training_finished = self._t0s[RunningStage.TRAINING] + sum(self._throughputs[RunningStage.TRAINING]._time)
+        time_between_train_and_val = self._t0s[RunningStage.VALIDATING] - training_finished
+        val_time = sum(self._throughputs[RunningStage.VALIDATING]._time)
         self._t0s[RunningStage.TRAINING] += time_between_train_and_val + val_time
 
     @rank_zero_only
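Reading the diff: self._t0s[...] holds absolute start timestamps (perf_counter-style) while the recorded _time values are elapsed durations, so the removed lines appear to subtract a duration from a timestamp; the replacement first reconstructs the absolute moment training finished and only then measures the gap to validation. Below is a minimal sketch of the same bookkeeping idea; the class and parameter names (StageClock, on_validation_run_end, val_t0, val_duration) are hypothetical and are not Lightning's API, they only illustrate why the training start time is shifted forward by the idle gap plus the validation time.

import time


class StageClock:
    """Toy illustration of the bookkeeping in the diff (hypothetical names, not Lightning's API):
    training throughput is measured against an adjusted start time, and every validation run
    pushes that start time forward so validation does not sink the training throughput."""

    def __init__(self) -> None:
        self.train_t0 = time.perf_counter()  # absolute start of the training-only clock
        self.train_elapsed = 0.0             # training seconds recorded so far
        self.samples_seen = 0

    def on_train_batch_end(self, batch_size: int) -> float:
        # elapsed time is measured from the adjusted start, so validation windows that were
        # already added back via on_validation_run_end never count towards training time
        self.samples_seen += batch_size
        self.train_elapsed = time.perf_counter() - self.train_t0
        return self.samples_seen / max(self.train_elapsed, 1e-9)  # training-only samples/sec

    def on_validation_run_end(self, val_t0: float, val_duration: float) -> None:
        # absolute timestamp at which the last recorded training measurement ended
        training_finished = self.train_t0 + self.train_elapsed
        # idle gap between the end of training and the start of validation
        time_between_train_and_val = val_t0 - training_finished
        # shift the training clock forward so the gap and the validation time are excluded
        self.train_t0 += time_between_train_and_val + val_duration

In this sketch, a validation run that takes, say, 30 seconds simply moves train_t0 forward by the idle gap plus those 30 seconds, so the next throughput reading stays consistent with the one logged just before validation started; that is the role of the unchanged `self._t0s[RunningStage.TRAINING] += time_between_train_and_val + val_time` line in the diff above.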