From bf1394a47263d828054ce3e286c622fdb82a64d1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Mon, 3 May 2021 22:20:48 +0200
Subject: [PATCH] improve early stopping verbose logging (#6811)

---
 CHANGELOG.md                                  |  3 +++
 pytorch_lightning/callbacks/early_stopping.py | 26 ++++++++++++++++---
 2 files changed, 26 insertions(+), 3 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 6573a11692..f8944e2e10 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -145,6 +145,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added warning when missing `Callback` and using `resume_from_checkpoint` ([#7254](https://github.com/PyTorchLightning/pytorch-lightning/pull/7254))
 
+- Improved verbose logging for `EarlyStopping` callback ([#6811](https://github.com/PyTorchLightning/pytorch-lightning/pull/6811))
+
+
 ### Changed
 
diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index 680bc55ed2..ba9ab188d4 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -24,6 +24,7 @@ from typing import Any, Callable, Dict, Optional, Tuple
 import numpy as np
 import torch
 
+import pytorch_lightning as pl
 from pytorch_lightning.callbacks.base import Callback
 from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -196,8 +197,8 @@ class EarlyStopping(Callback):
         trainer.should_stop = trainer.should_stop or should_stop
         if should_stop:
             self.stopped_epoch = trainer.current_epoch
-        if reason:
-            log.info(f"[{trainer.global_rank}] {reason}")
+        if reason and self.verbose:
+            self._log_info(trainer, reason)
 
     def _evalute_stopping_criteria(self, current: torch.Tensor) -> Tuple[bool, str]:
         should_stop = False
@@ -224,6 +225,7 @@ class EarlyStopping(Callback):
             )
         elif self.monitor_op(current - self.min_delta, self.best_score):
             should_stop = False
+            reason = self._improvement_message(current)
             self.best_score = current
             self.wait_count = 0
         else:
@@ -231,8 +233,26 @@ class EarlyStopping(Callback):
             self.wait_count += 1
             if self.wait_count >= self.patience:
                 should_stop = True
                 reason = (
-                    f"Monitored metric {self.monitor} did not improve in the last {self.wait_count} epochs."
+                    f"Monitored metric {self.monitor} did not improve in the last {self.wait_count} records."
                     f" Best score: {self.best_score:.3f}. Signaling Trainer to stop."
                 )
 
         return should_stop, reason
+
+    def _improvement_message(self, current: torch.Tensor) -> str:
+        """ Formats a log message that informs the user about an improvement in the monitored score. """
+        if torch.isfinite(self.best_score):
+            msg = (
+                f"Metric {self.monitor} improved by {abs(self.best_score - current):.3f} >="
+                f" min_delta = {abs(self.min_delta)}. New best score: {current:.3f}"
+            )
+        else:
+            msg = f"Metric {self.monitor} improved. New best score: {current:.3f}"
+        return msg
+
+    @staticmethod
+    def _log_info(trainer: Optional["pl.Trainer"], message: str) -> None:
+        if trainer is not None and trainer.world_size > 1:
+            log.info(f"[rank: {trainer.global_rank}] {message}")
+        else:
+            log.info(message)
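
A minimal usage sketch (not part of the patch) of how the new verbose logging would surface; the `val_loss` metric and the Trainer arguments below are illustrative assumptions, not taken from this commit:

    # Hypothetical example: with verbose=True, the callback logs the
    # improvement message each time the monitored metric improves and the
    # "did not improve" reason once patience is exhausted; the "[rank: N]"
    # prefix is added only when trainer.world_size > 1.
    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import EarlyStopping

    early_stop = EarlyStopping(monitor="val_loss", min_delta=0.0, patience=3, verbose=True)
    trainer = Trainer(callbacks=[early_stop], max_epochs=50)
    # trainer.fit(model, datamodule=dm)  # model and dm assumed to be defined elsewhere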