ThroughputMonitor Trainer callback fixes (#19027)
This commit is contained in:
parent
dd206a20d8
commit
6cbe9ceb56
|
@ -193,14 +193,14 @@ class ThroughputMonitor(Callback):
|
|||
self._update(trainer, pl_module, batch, iter_num)
|
||||
self._compute(trainer, iter_num)
|
||||
|
||||
@rank_zero_only
def on_validation_end(self, trainer: "Trainer", *_: Any) -> None:
    """Push the training start time forward by the time spent outside training.

    Does nothing while the trainer is sanity checking or when the trainer is
    not running a fit. Otherwise ``self._t0s[RunningStage.TRAINING]`` is
    increased by the gap between the end of training and the start of
    validation, plus the time spent validating, so that this time does not
    sink the training throughput.

    Args:
        trainer: The trainer whose ``sanity_checking`` flag and ``state.fn``
            decide whether the adjustment applies.
        *_: Any further positional hook arguments; ignored.
    """
    if trainer.sanity_checking or trainer.state.fn != TrainerFn.FITTING:
        return

    train_times = self._throughputs[RunningStage.TRAINING]._time
    val_times = self._throughputs[RunningStage.VALIDATING]._time

    # add the validation time to the training time before continuing to avoid sinking the training throughput
    # NOTE(review): summing assumes each `_time` entry is a per-step duration
    # rather than a cumulative timestamp — confirm against the Throughput class.
    # `sum` is also safe on an empty window, unlike a `[-1]` lookup.
    training_finished = self._t0s[RunningStage.TRAINING] + sum(train_times)
    time_between_train_and_val = self._t0s[RunningStage.VALIDATING] - training_finished
    val_time = sum(val_times)
    self._t0s[RunningStage.TRAINING] += time_between_train_and_val + val_time
|
||||
|
||||
@rank_zero_only
|
||||
|
|
Loading…
Reference in New Issue