fix best score on wrong device in EarlyStopping callback (#8295)
This commit is contained in:
parent
8fead58273
commit
1e1d1821d0
|
@ -200,7 +200,7 @@ class EarlyStopping(Callback):
|
|||
# when in dev debugging
|
||||
trainer.dev_debugger.track_early_stopping_history(self, current)
|
||||
|
||||
should_stop, reason = self._evalute_stopping_criteria(current, trainer)
|
||||
should_stop, reason = self._evalute_stopping_criteria(current)
|
||||
|
||||
# stop every ddp process if any world process decides to stop
|
||||
should_stop = trainer.training_type_plugin.reduce_boolean_decision(should_stop)
|
||||
|
@ -210,7 +210,7 @@ class EarlyStopping(Callback):
|
|||
if reason and self.verbose:
|
||||
self._log_info(trainer, reason)
|
||||
|
||||
def _evalute_stopping_criteria(self, current: torch.Tensor, trainer: 'pl.Trainer') -> Tuple[bool, str]:
|
||||
def _evalute_stopping_criteria(self, current: torch.Tensor) -> Tuple[bool, str]:
|
||||
should_stop = False
|
||||
reason = None
|
||||
if self.check_finite and not torch.isfinite(current):
|
||||
|
@ -233,7 +233,7 @@ class EarlyStopping(Callback):
|
|||
f" {self.monitor} = {current} {self.order_dict[self.mode]} {self.divergence_threshold}."
|
||||
" Signaling Trainer to stop."
|
||||
)
|
||||
elif self.monitor_op(current - self.min_delta, self.best_score.to(trainer.lightning_module.device)):
|
||||
elif self.monitor_op(current - self.min_delta, self.best_score.to(current.device)):
|
||||
should_stop = False
|
||||
reason = self._improvement_message(current)
|
||||
self.best_score = current
|
||||
|
|
Loading…
Reference in New Issue