From 1e1d1821d070d2ed02c64a718bc81dbc6468c46b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 6 Jul 2021 10:59:33 +0200 Subject: [PATCH] fix best score on wrong device in EarlyStopping callback (#8295) --- pytorch_lightning/callbacks/early_stopping.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index 98675866b0..c28e5cec5b 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -200,7 +200,7 @@ class EarlyStopping(Callback): # when in dev debugging trainer.dev_debugger.track_early_stopping_history(self, current) - should_stop, reason = self._evalute_stopping_criteria(current, trainer) + should_stop, reason = self._evalute_stopping_criteria(current) # stop every ddp process if any world process decides to stop should_stop = trainer.training_type_plugin.reduce_boolean_decision(should_stop) @@ -210,7 +210,7 @@ class EarlyStopping(Callback): if reason and self.verbose: self._log_info(trainer, reason) - def _evalute_stopping_criteria(self, current: torch.Tensor, trainer: 'pl.Trainer') -> Tuple[bool, str]: + def _evalute_stopping_criteria(self, current: torch.Tensor) -> Tuple[bool, str]: should_stop = False reason = None if self.check_finite and not torch.isfinite(current): @@ -233,7 +233,7 @@ class EarlyStopping(Callback): f" {self.monitor} = {current} {self.order_dict[self.mode]} {self.divergence_threshold}." " Signaling Trainer to stop." ) - elif self.monitor_op(current - self.min_delta, self.best_score.to(trainer.lightning_module.device)): + elif self.monitor_op(current - self.min_delta, self.best_score.to(current.device)): should_stop = False reason = self._improvement_message(current) self.best_score = current