diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index 98675866b0..c28e5cec5b 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -200,7 +200,7 @@ class EarlyStopping(Callback): # when in dev debugging trainer.dev_debugger.track_early_stopping_history(self, current) - should_stop, reason = self._evalute_stopping_criteria(current, trainer) + should_stop, reason = self._evalute_stopping_criteria(current) # stop every ddp process if any world process decides to stop should_stop = trainer.training_type_plugin.reduce_boolean_decision(should_stop) @@ -210,7 +210,7 @@ class EarlyStopping(Callback): if reason and self.verbose: self._log_info(trainer, reason) - def _evalute_stopping_criteria(self, current: torch.Tensor, trainer: 'pl.Trainer') -> Tuple[bool, str]: + def _evalute_stopping_criteria(self, current: torch.Tensor) -> Tuple[bool, str]: should_stop = False reason = None if self.check_finite and not torch.isfinite(current): @@ -233,7 +233,7 @@ class EarlyStopping(Callback): f" {self.monitor} = {current} {self.order_dict[self.mode]} {self.divergence_threshold}." " Signaling Trainer to stop." ) - elif self.monitor_op(current - self.min_delta, self.best_score.to(trainer.lightning_module.device)): + elif self.monitor_op(current - self.min_delta, self.best_score.to(current.device)): should_stop = False reason = self._improvement_message(current) self.best_score = current