From ae9803137a2e5a342c650fbf817a46fcf7620b0f Mon Sep 17 00:00:00 2001 From: Jinyoung Lim Date: Sat, 23 Jul 2022 05:07:47 -0700 Subject: [PATCH] Add logging messages to notify when `FitLoop` stopping conditions are met (#9749) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos Mocholí --- src/pytorch_lightning/CHANGELOG.md | 3 ++ src/pytorch_lightning/loops/fit_loop.py | 29 ++++++++---- .../tests_pytorch/loops/test_training_loop.py | 47 +++++++++++++++++++ tests/tests_pytorch/trainer/test_trainer.py | 2 +- 4 files changed, 70 insertions(+), 11 deletions(-) diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index 7286d851c7..af53c9b063 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -93,6 +93,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `XLAEnvironment` cluster environment plugin ([#11330](https://github.com/PyTorchLightning/pytorch-lightning/pull/11330)) +- Added logging messages to notify when `FitLoop` stopping conditions are met ([#9749](https://github.com/PyTorchLightning/pytorch-lightning/pull/9749)) + + - Added support for calling unknown methods with `DummyLogger` ([#13224](https://github.com/PyTorchLightning/pytorch-lightning/pull/13224) diff --git a/src/pytorch_lightning/loops/fit_loop.py b/src/pytorch_lightning/loops/fit_loop.py index 8b54579a6b..f4f7735f4b 100644 --- a/src/pytorch_lightning/loops/fit_loop.py +++ b/src/pytorch_lightning/loops/fit_loop.py @@ -33,7 +33,7 @@ from pytorch_lightning.utilities.fetching import ( InterBatchParallelDataFetcher, ) from pytorch_lightning.utilities.model_helpers import is_overridden -from pytorch_lightning.utilities.rank_zero import rank_zero_warn +from pytorch_lightning.utilities.rank_zero import rank_zero_debug, rank_zero_info, rank_zero_warn from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature log = 
logging.getLogger(__name__) @@ -150,31 +150,40 @@ class FitLoop(Loop[None]): @property def done(self) -> bool: """Evaluates when to leave the loop.""" + if self.trainer.num_training_batches == 0: + rank_zero_info("`Trainer.fit` stopped: No training batches.") + return True + # TODO(@awaelchli): Move track steps inside training loop and move part of these condition inside training loop stop_steps = _is_max_limit_reached(self.epoch_loop.global_step, self.max_steps) + if stop_steps: + rank_zero_info(f"`Trainer.fit` stopped: `max_steps={self.max_steps!r}` reached.") + return True + # `processed` is increased before `on_train_epoch_end`, the hook where checkpoints are typically saved. # we use it here because the checkpoint data won't have `completed` increased yet stop_epochs = _is_max_limit_reached(self.epoch_progress.current.processed, self.max_epochs) if stop_epochs: # in case they are not equal, override so `trainer.current_epoch` has the expected value self.epoch_progress.current.completed = self.epoch_progress.current.processed + rank_zero_info(f"`Trainer.fit` stopped: `max_epochs={self.max_epochs!r}` reached.") + return True - should_stop = False if self.trainer.should_stop: # early stopping met_min_epochs = self.epoch_progress.current.processed >= self.min_epochs if self.min_epochs else True met_min_steps = self.epoch_loop.global_step >= self.min_steps if self.min_steps else True if met_min_epochs and met_min_steps: - should_stop = True + self.trainer.should_stop = True + rank_zero_debug("`Trainer.fit` stopped: `trainer.should_stop` was set.") + return True else: - log.info( - "Trainer was signaled to stop but required minimum epochs" - f" ({self.min_epochs}) or minimum steps ({self.min_steps}) has" - " not been met. Training will continue..." + rank_zero_info( + f"Trainer was signaled to stop but the required `min_epochs={self.min_epochs!r}` or" + f" `min_steps={self.min_steps!r}` has not been met. Training will continue..." 
) - self.trainer.should_stop = should_stop - - return stop_steps or should_stop or stop_epochs or self.trainer.num_training_batches == 0 + self.trainer.should_stop = False + return False @property def skip(self) -> bool: diff --git a/tests/tests_pytorch/loops/test_training_loop.py b/tests/tests_pytorch/loops/test_training_loop.py index 9c5d24f96a..a9da6dcf2b 100644 --- a/tests/tests_pytorch/loops/test_training_loop.py +++ b/tests/tests_pytorch/loops/test_training_loop.py @@ -11,11 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging +from unittest.mock import Mock + import pytest import torch from pytorch_lightning import seed_everything, Trainer from pytorch_lightning.demos.boring_classes import BoringModel +from pytorch_lightning.loops import FitLoop def test_outputs_format(tmpdir): @@ -136,6 +140,49 @@ def test_should_stop_mid_epoch(tmpdir): assert model.validation_called_at == (0, 5) +def test_fit_loop_done_log_messages(caplog): + fit_loop = FitLoop() + trainer = Mock(spec=Trainer) + fit_loop.trainer = trainer + + trainer.should_stop = False + trainer.num_training_batches = 5 + assert not fit_loop.done + assert not caplog.messages + + trainer.num_training_batches = 0 + assert fit_loop.done + assert "No training batches" in caplog.text + caplog.clear() + trainer.num_training_batches = 5 + + epoch_loop = Mock() + epoch_loop.global_step = 10 + fit_loop.connect(epoch_loop=epoch_loop) + fit_loop.max_steps = 10 + assert fit_loop.done + assert "max_steps=10` reached" in caplog.text + caplog.clear() + fit_loop.max_steps = 20 + + fit_loop.epoch_progress.current.processed = 3 + fit_loop.max_epochs = 3 + trainer.should_stop = True + assert fit_loop.done + assert "max_epochs=3` reached" in caplog.text + caplog.clear() + fit_loop.max_epochs = 5 + + fit_loop.epoch_loop.min_steps = 0 + with caplog.at_level(level=logging.DEBUG, 
logger="pytorch_lightning.utilities.rank_zero"): + assert fit_loop.done + assert "should_stop` was set" in caplog.text + + fit_loop.epoch_loop.min_steps = 100 + assert not fit_loop.done + assert "was signaled to stop but" in caplog.text + + def test_warning_valid_train_step_end(tmpdir): class ValidTrainStepEndModel(BoringModel): def training_step(self, batch, batch_idx): diff --git a/tests/tests_pytorch/trainer/test_trainer.py b/tests/tests_pytorch/trainer/test_trainer.py index 3c82e6de84..ecc0ad724e 100644 --- a/tests/tests_pytorch/trainer/test_trainer.py +++ b/tests/tests_pytorch/trainer/test_trainer.py @@ -616,7 +616,7 @@ def test_trainer_min_steps_and_min_epochs_not_reached(tmpdir, caplog): with caplog.at_level(logging.INFO, logger="pytorch_lightning.trainer.trainer"): trainer.fit(model) - message = f"minimum epochs ({min_epochs}) or minimum steps (None) has not been met. Training will continue" + message = f"min_epochs={min_epochs}` or `min_steps=None` has not been met. Training will continue" num_messages = sum(1 for record in caplog.records if message in record.message) assert num_messages == min_epochs - 2 assert model.training_step_invoked == min_epochs * 2