improve early stopping verbose logging (#6811)
This commit is contained in:
parent
393b252ef0
commit
bf1394a472
|
@ -145,6 +145,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Added warning when missing `Callback` and using `resume_from_checkpoint` ([#7254](https://github.com/PyTorchLightning/pytorch-lightning/pull/7254))
|
||||
|
||||
|
||||
- Improved verbose logging for `EarlyStopping` callback ([#6811](https://github.com/PyTorchLightning/pytorch-lightning/pull/6811))
|
||||
|
||||
|
||||
### Changed
|
||||
|
||||
|
||||
|
|
|
@ -24,6 +24,7 @@ from typing import Any, Callable, Dict, Optional, Tuple
|
|||
import numpy as np
|
||||
import torch
|
||||
|
||||
import pytorch_lightning as pl
|
||||
from pytorch_lightning.callbacks.base import Callback
|
||||
from pytorch_lightning.utilities import rank_zero_warn
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
|
@ -196,8 +197,8 @@ class EarlyStopping(Callback):
|
|||
trainer.should_stop = trainer.should_stop or should_stop
|
||||
if should_stop:
|
||||
self.stopped_epoch = trainer.current_epoch
|
||||
if reason:
|
||||
log.info(f"[{trainer.global_rank}] {reason}")
|
||||
if reason and self.verbose:
|
||||
self._log_info(trainer, reason)
|
||||
|
||||
def _evalute_stopping_criteria(self, current: torch.Tensor) -> Tuple[bool, str]:
|
||||
should_stop = False
|
||||
|
@ -224,6 +225,7 @@ class EarlyStopping(Callback):
|
|||
)
|
||||
elif self.monitor_op(current - self.min_delta, self.best_score):
|
||||
should_stop = False
|
||||
reason = self._improvement_message(current)
|
||||
self.best_score = current
|
||||
self.wait_count = 0
|
||||
else:
|
||||
|
@ -231,8 +233,26 @@ class EarlyStopping(Callback):
|
|||
if self.wait_count >= self.patience:
|
||||
should_stop = True
|
||||
reason = (
|
||||
f"Monitored metric {self.monitor} did not improve in the last {self.wait_count} epochs."
|
||||
f"Monitored metric {self.monitor} did not improve in the last {self.wait_count} records."
|
||||
f" Best score: {self.best_score:.3f}. Signaling Trainer to stop."
|
||||
)
|
||||
|
||||
return should_stop, reason
|
||||
|
||||
def _improvement_message(self, current: torch.Tensor) -> str:
|
||||
""" Formats a log message that informs the user about an improvement in the monitored score. """
|
||||
if torch.isfinite(self.best_score):
|
||||
msg = (
|
||||
f"Metric {self.monitor} improved by {abs(self.best_score - current):.3f} >="
|
||||
f" min_delta = {abs(self.min_delta)}. New best score: {current:.3f}"
|
||||
)
|
||||
else:
|
||||
msg = f"Metric {self.monitor} improved. New best score: {current:.3f}"
|
||||
return msg
|
||||
|
||||
@staticmethod
|
||||
def _log_info(trainer: Optional["pl.Trainer"], message: str) -> None:
|
||||
if trainer is not None and trainer.world_size > 1:
|
||||
log.info(f"[rank: {trainer.global_rank}] {message}")
|
||||
else:
|
||||
log.info(message)
|
||||
|
|
Loading…
Reference in New Issue