diff --git a/CHANGELOG.md b/CHANGELOG.md index 7c19a41a64..6b5ddef3e0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -284,6 +284,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Pass the `stage` argument of `Callback.{setup,teardown}` as a keyword ([#7973](https://github.com/PyTorchLightning/pytorch-lightning/pull/7973)) +- Fixed move best score to device in EarlyStopping Callback ([#7959](https://github.com/PyTorchLightning/pytorch-lightning/pull/7959)) + + ## [1.3.6] - 2021-06-15 ### Fixed diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index e40bb7180c..b6bff43fd6 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -196,7 +196,7 @@ class EarlyStopping(Callback): # when in dev debugging trainer.dev_debugger.track_early_stopping_history(self, current) - should_stop, reason = self._evalute_stopping_criteria(current) + should_stop, reason = self._evalute_stopping_criteria(current, trainer) # stop every ddp process if any world process decides to stop should_stop = trainer.training_type_plugin.reduce_boolean_decision(should_stop) @@ -206,7 +206,7 @@ class EarlyStopping(Callback): if reason and self.verbose: self._log_info(trainer, reason) - def _evalute_stopping_criteria(self, current: torch.Tensor) -> Tuple[bool, str]: + def _evalute_stopping_criteria(self, current: torch.Tensor, trainer: 'pl.Trainer') -> Tuple[bool, str]: should_stop = False reason = None if self.check_finite and not torch.isfinite(current): @@ -229,7 +229,7 @@ class EarlyStopping(Callback): f" {self.monitor} = {current} {self.order_dict[self.mode]} {self.divergence_threshold}." " Signaling Trainer to stop." ) - elif self.monitor_op(current - self.min_delta, self.best_score): + elif self.monitor_op(current - self.min_delta, self.best_score.to(trainer.lightning_module.device)): should_stop = False reason = self._improvement_message(current) self.best_score = current