improve early stopping verbose logging (#6811)

This commit is contained in:
Adrian Wälchli 2021-05-03 22:20:48 +02:00 committed by GitHub
parent 393b252ef0
commit bf1394a472
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 26 additions and 3 deletions

View File

@@ -145,6 +145,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added warning when missing `Callback` and using `resume_from_checkpoint` ([#7254](https://github.com/PyTorchLightning/pytorch-lightning/pull/7254))
- Improved verbose logging for `EarlyStopping` callback ([#6811](https://github.com/PyTorchLightning/pytorch-lightning/pull/6811))
### Changed

View File

@@ -24,6 +24,7 @@ from typing import Any, Callable, Dict, Optional, Tuple
import numpy as np
import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -196,8 +197,8 @@ class EarlyStopping(Callback):
trainer.should_stop = trainer.should_stop or should_stop
if should_stop:
self.stopped_epoch = trainer.current_epoch
if reason:
log.info(f"[{trainer.global_rank}] {reason}")
if reason and self.verbose:
self._log_info(trainer, reason)
def _evalute_stopping_criteria(self, current: torch.Tensor) -> Tuple[bool, str]:
should_stop = False
@@ -224,6 +225,7 @@
)
elif self.monitor_op(current - self.min_delta, self.best_score):
should_stop = False
reason = self._improvement_message(current)
self.best_score = current
self.wait_count = 0
else:
@@ -231,8 +233,26 @@
if self.wait_count >= self.patience:
should_stop = True
reason = (
f"Monitored metric {self.monitor} did not improve in the last {self.wait_count} epochs."
f"Monitored metric {self.monitor} did not improve in the last {self.wait_count} records."
f" Best score: {self.best_score:.3f}. Signaling Trainer to stop."
)
return should_stop, reason
def _improvement_message(self, current: torch.Tensor) -> str:
    """Return a human-readable message describing how the monitored metric improved.

    When a finite best score exists, the message quantifies the improvement
    against ``min_delta``; otherwise (first recorded value, or a previous
    ``inf``/``nan`` best) it only reports the new best score.
    """
    # No finite baseline to compare against -> report the new best only.
    if not torch.isfinite(self.best_score):
        return f"Metric {self.monitor} improved. New best score: {current:.3f}"
    # Note: called before `self.best_score` is updated, so the delta below
    # measures the improvement over the previous best.
    return (
        f"Metric {self.monitor} improved by {abs(self.best_score - current):.3f} >="
        f" min_delta = {abs(self.min_delta)}. New best score: {current:.3f}"
    )
@staticmethod
def _log_info(trainer: Optional["pl.Trainer"], message: str) -> None:
    """Log ``message``, prefixing it with the global rank when running distributed.

    The rank prefix is only added when a trainer is available and more than
    one process is participating (``world_size > 1``); otherwise the message
    is logged as-is.
    """
    if trainer is None or trainer.world_size <= 1:
        log.info(message)
    else:
        log.info(f"[rank: {trainer.global_rank}] {message}")