Fix(Early Stopping): move best score to device (#7959)

Kaushik B 2021-06-21 15:41:41 +05:30 committed by GitHub
parent 92a78d58c3
commit 2303f9ced8
2 changed files with 6 additions and 3 deletions

CHANGELOG.md

@@ -284,6 +284,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Pass the `stage` argument of `Callback.{setup,teardown}` as a keyword ([#7973](https://github.com/PyTorchLightning/pytorch-lightning/pull/7973))
+- Fixed move best score to device in EarlyStopping Callback ([#7959](https://github.com/PyTorchLightning/pytorch-lightning/pull/7959))
 ## [1.3.6] - 2021-06-15
 ### Fixed

pytorch_lightning/callbacks/early_stopping.py

@@ -196,7 +196,7 @@ class EarlyStopping(Callback):
         # when in dev debugging
         trainer.dev_debugger.track_early_stopping_history(self, current)

-        should_stop, reason = self._evalute_stopping_criteria(current)
+        should_stop, reason = self._evalute_stopping_criteria(current, trainer)

         # stop every ddp process if any world process decides to stop
         should_stop = trainer.training_type_plugin.reduce_boolean_decision(should_stop)
@@ -206,7 +206,7 @@ class EarlyStopping(Callback):
         if reason and self.verbose:
             self._log_info(trainer, reason)

-    def _evalute_stopping_criteria(self, current: torch.Tensor) -> Tuple[bool, str]:
+    def _evalute_stopping_criteria(self, current: torch.Tensor, trainer: 'pl.Trainer') -> Tuple[bool, str]:
         should_stop = False
         reason = None
         if self.check_finite and not torch.isfinite(current):
@@ -229,7 +229,7 @@ class EarlyStopping(Callback):
                 f" {self.monitor} = {current} {self.order_dict[self.mode]} {self.divergence_threshold}."
                 " Signaling Trainer to stop."
             )
-        elif self.monitor_op(current - self.min_delta, self.best_score):
+        elif self.monitor_op(current - self.min_delta, self.best_score.to(trainer.lightning_module.device)):
             should_stop = False
             reason = self._improvement_message(current)
             self.best_score = current
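
For context on the added `.to(...)` call: `best_score` starts out as a CPU tensor (and is typically restored to CPU when loading a checkpoint), while the monitored metric `current` lives on the LightningModule's device, e.g. a GPU, and comparing tensors that sit on different devices can raise a device-mismatch error. Below is a minimal sketch of the fixed comparison outside of Lightning, using made-up metric values and assuming `mode="min"` (so `monitor_op` is `torch.lt`):

```python
import torch

# EarlyStopping with mode="min" compares via torch.lt; mode="max" would use torch.gt.
monitor_op = torch.lt
min_delta = 0.0

# best_score typically starts out (or is reloaded from a checkpoint) on the CPU ...
best_score = torch.tensor(float("inf"))

# ... while the monitored metric lives on the model's device (GPU here, if available).
device = "cuda" if torch.cuda.is_available() else "cpu"
current = torch.tensor(0.25, device=device)

# Moving best_score onto the metric's device keeps the comparison on a single device,
# which is what the patch achieves via trainer.lightning_module.device.
improved = monitor_op(current - min_delta, best_score.to(current.device))
print(bool(improved))  # True -> corresponds to the `elif` improvement branch above
```

Note that the patch resolves the target device from `trainer.lightning_module.device` rather than from `current`, which is why `_evalute_stopping_criteria` now receives `trainer` as an additional argument.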