Fix(Early Stopping): move best score to device (#7959)
This commit is contained in:
parent
92a78d58c3
commit
2303f9ced8
|
@ -284,6 +284,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
||||||
- Pass the `stage` argument of `Callback.{setup,teardown}` as a keyword ([#7973](https://github.com/PyTorchLightning/pytorch-lightning/pull/7973))
|
- Pass the `stage` argument of `Callback.{setup,teardown}` as a keyword ([#7973](https://github.com/PyTorchLightning/pytorch-lightning/pull/7973))
|
||||||
|
|
||||||
|
|
||||||
|
- Fixed move best score to device in EarlyStopping Callback ([#7959](https://github.com/PyTorchLightning/pytorch-lightning/pull/7959))
|
||||||
|
|
||||||
|
|
||||||
## [1.3.6] - 2021-06-15
|
## [1.3.6] - 2021-06-15
|
||||||
|
|
||||||
### Fixed
|
### Fixed
|
||||||
|
|
|
@ -196,7 +196,7 @@ class EarlyStopping(Callback):
|
||||||
# when in dev debugging
|
# when in dev debugging
|
||||||
trainer.dev_debugger.track_early_stopping_history(self, current)
|
trainer.dev_debugger.track_early_stopping_history(self, current)
|
||||||
|
|
||||||
should_stop, reason = self._evalute_stopping_criteria(current)
|
should_stop, reason = self._evalute_stopping_criteria(current, trainer)
|
||||||
|
|
||||||
# stop every ddp process if any world process decides to stop
|
# stop every ddp process if any world process decides to stop
|
||||||
should_stop = trainer.training_type_plugin.reduce_boolean_decision(should_stop)
|
should_stop = trainer.training_type_plugin.reduce_boolean_decision(should_stop)
|
||||||
|
@ -206,7 +206,7 @@ class EarlyStopping(Callback):
|
||||||
if reason and self.verbose:
|
if reason and self.verbose:
|
||||||
self._log_info(trainer, reason)
|
self._log_info(trainer, reason)
|
||||||
|
|
||||||
def _evalute_stopping_criteria(self, current: torch.Tensor) -> Tuple[bool, str]:
|
def _evalute_stopping_criteria(self, current: torch.Tensor, trainer: 'pl.Trainer') -> Tuple[bool, str]:
|
||||||
should_stop = False
|
should_stop = False
|
||||||
reason = None
|
reason = None
|
||||||
if self.check_finite and not torch.isfinite(current):
|
if self.check_finite and not torch.isfinite(current):
|
||||||
|
@ -229,7 +229,7 @@ class EarlyStopping(Callback):
|
||||||
f" {self.monitor} = {current} {self.order_dict[self.mode]} {self.divergence_threshold}."
|
f" {self.monitor} = {current} {self.order_dict[self.mode]} {self.divergence_threshold}."
|
||||||
" Signaling Trainer to stop."
|
" Signaling Trainer to stop."
|
||||||
)
|
)
|
||||||
elif self.monitor_op(current - self.min_delta, self.best_score):
|
elif self.monitor_op(current - self.min_delta, self.best_score.to(trainer.lightning_module.device)):
|
||||||
should_stop = False
|
should_stop = False
|
||||||
reason = self._improvement_message(current)
|
reason = self._improvement_message(current)
|
||||||
self.best_score = current
|
self.best_score = current
|
||||||
|
|
Loading…
Reference in New Issue