diff --git a/src/lightning/pytorch/callbacks/throughput_monitor.py b/src/lightning/pytorch/callbacks/throughput_monitor.py
index 680d4afa65..23abe454df 100644
--- a/src/lightning/pytorch/callbacks/throughput_monitor.py
+++ b/src/lightning/pytorch/callbacks/throughput_monitor.py
@@ -193,14 +193,14 @@ class ThroughputMonitor(Callback):
         self._update(trainer, pl_module, batch, iter_num)
         self._compute(trainer, iter_num)
 
+    @rank_zero_only
     def on_validation_end(self, trainer: "Trainer", *_: Any) -> None:
         if trainer.sanity_checking or trainer.state.fn != TrainerFn.FITTING:
             return
         # add the validation time to the training time before continuing to avoid sinking the training throughput
-        time_between_train_and_val = (
-            self._t0s[RunningStage.VALIDATING] - self._throughputs[RunningStage.TRAINING]._time[-1]
-        )
-        val_time = self._throughputs[RunningStage.VALIDATING]._time[-1]
+        training_finished = self._t0s[RunningStage.TRAINING] + sum(self._throughputs[RunningStage.TRAINING]._time)
+        time_between_train_and_val = self._t0s[RunningStage.VALIDATING] - training_finished
+        val_time = sum(self._throughputs[RunningStage.VALIDATING]._time)
         self._t0s[RunningStage.TRAINING] += time_between_train_and_val + val_time
 
     @rank_zero_only
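
For context, a minimal standalone sketch of the bookkeeping this hunk performs: after validation ends, the training reference time `t0` is shifted forward by the train-to-val gap plus the total validation time, so the pause does not deflate the computed training throughput. This is not the Lightning implementation; the dict keys, helper name, and the assumption that `_time` holds per-interval durations (so `sum(...)` yields the total time spent in a stage, matching the `sum()` calls in the diff) are illustrative only.

```python
import time

# Illustrative state: a start timestamp and a list of recorded durations per stage.
t0s = {"train": time.perf_counter(), "val": 0.0}
times: dict[str, list[float]] = {"train": [], "val": []}


def on_validation_end() -> None:
    # Training was last measured at: training t0 + total training time recorded so far.
    training_finished = t0s["train"] + sum(times["train"])
    # Dead time for training throughput: the gap before validation started,
    # plus the full duration spent in validation.
    time_between_train_and_val = t0s["val"] - training_finished
    val_time = sum(times["val"])
    # Shift the training clock forward so later elapsed-time computations
    # (now - t0) exclude the validation pause.
    t0s["train"] += time_between_train_and_val + val_time
```

The added `@rank_zero_only` decorator in the hunk restricts this adjustment to the rank-zero process, consistent with the other decorated hooks in the callback.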