fix best score on wrong device in EarlyStopping callback (#8295)

This commit is contained in:
Adrian Wälchli 2021-07-06 10:59:33 +02:00 committed by GitHub
parent 8fead58273
commit 1e1d1821d0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 3 additions and 3 deletions

View File

@ -200,7 +200,7 @@ class EarlyStopping(Callback):
# when in dev debugging
trainer.dev_debugger.track_early_stopping_history(self, current)
should_stop, reason = self._evalute_stopping_criteria(current, trainer)
should_stop, reason = self._evalute_stopping_criteria(current)
# stop every ddp process if any world process decides to stop
should_stop = trainer.training_type_plugin.reduce_boolean_decision(should_stop)
@ -210,7 +210,7 @@ class EarlyStopping(Callback):
if reason and self.verbose:
self._log_info(trainer, reason)
def _evalute_stopping_criteria(self, current: torch.Tensor, trainer: 'pl.Trainer') -> Tuple[bool, str]:
def _evalute_stopping_criteria(self, current: torch.Tensor) -> Tuple[bool, str]:
should_stop = False
reason = None
if self.check_finite and not torch.isfinite(current):
@ -233,7 +233,7 @@ class EarlyStopping(Callback):
f" {self.monitor} = {current} {self.order_dict[self.mode]} {self.divergence_threshold}."
" Signaling Trainer to stop."
)
elif self.monitor_op(current - self.min_delta, self.best_score.to(trainer.lightning_module.device)):
elif self.monitor_op(current - self.min_delta, self.best_score.to(current.device)):
should_stop = False
reason = self._improvement_message(current)
self.best_score = current