From 127c52af747066b184faed215ab009ea6d77ad21 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Tue, 6 Apr 2021 13:41:07 +0200
Subject: [PATCH] Fix EarlyStopping logic when min_epochs not met (#6705)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Carlos Mocholí
---
 CHANGELOG.md                         | 10 ++++----
 pytorch_lightning/trainer/trainer.py |  1 +
 tests/trainer/test_trainer.py        | 35 ++++++++++++++++++++++++++++
 3 files changed, 42 insertions(+), 4 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 61fb9d7039..6c74cc13b8 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -209,6 +209,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed an issue with `IterableDataset` when `__len__` is not defined ([#6828](https://github.com/PyTorchLightning/pytorch-lightning/pull/6828))
 
 
+- Fixed `EarlyStopping` logic when `min_epochs` or `min_steps` requirement is not met ([#6705](https://github.com/PyTorchLightning/pytorch-lightning/pull/6705))
+
+
+- Fixed a bug where `TensorBoardLogger` would give a warning and not log correctly to a symbolic link `save_dir` ([#6730](https://github.com/PyTorchLightning/pytorch-lightning/pull/6730))
+
+
 ## [1.2.6] - 2021-03-30
 
 ### Changed
@@ -241,12 +247,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed comparing required versions ([#6434](https://github.com/PyTorchLightning/pytorch-lightning/pull/6434))
 - Fixed duplicate logs appearing in console when using the python logging module ([#6275](https://github.com/PyTorchLightning/pytorch-lightning/pull/6275))
 - Added Autocast in validation, test and predict modes for Native AMP ([#6565](https://github.com/PyTorchLightning/pytorch-lightning/pull/6565))
-
-
 - Fixed resolve a bug with omegaconf and xm.save ([#6741](https://github.com/PyTorchLightning/pytorch-lightning/pull/6741))
-- Fixed a bug where `TensorBoardLogger` would give a warning and not log correctly to a symbolic link `save_dir` ([#6730](https://github.com/PyTorchLightning/pytorch-lightning/pull/6730))
-
 
 
 ## [1.2.4] - 2021-03-16
 
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 27dcd6fe9a..36b591ba4e 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -614,6 +614,7 @@ class Trainer(
                 f' ({self.min_epochs}) or minimum steps ({self.min_steps}) has'
                 ' not been met. Training will continue...'
             )
+            self.should_stop = False
 
         # hook
         self.train_loop.on_train_end()
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 0c0488009d..bab26522ea 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging
 import math
 import os
 import pickle
@@ -528,6 +529,40 @@ def test_trainer_min_steps_and_epochs(tmpdir):
     assert trainer.global_step >= math.floor(num_train_samples * 1.5), "Model did not train for at least min_steps"
 
 
+def test_trainer_min_steps_and_min_epochs_not_reached(tmpdir, caplog):
+    """ Test that min_epochs/min_steps in Trainer are enforced even if EarlyStopping is triggered. """
""" + + class TestModel(BoringModel): + training_step_invoked = 0 + + def training_step(self, batch, batch_idx): + output = super().training_step(batch, batch_idx) + output["loss"] = output["loss"] * 0.0 # force minimal loss to trigger early stopping + self.log("loss", output["loss"]) + self.training_step_invoked += 1 + assert not self.trainer.should_stop + return output + + model = TestModel() + early_stop = EarlyStopping(monitor="loss", patience=0) + min_epochs = 5 + trainer = Trainer( + default_root_dir=tmpdir, + progress_bar_refresh_rate=0, + min_epochs=min_epochs, + limit_val_batches=0, + limit_train_batches=2, + callbacks=[early_stop] + ) + with caplog.at_level(logging.INFO, logger="pytorch_lightning.trainer.trainer"): + trainer.fit(model) + + message = f"minimum epochs ({min_epochs}) or minimum steps (None) has not been met. Training will continue" + num_messages = len([record.message for record in caplog.records if message in record.message]) + assert num_messages == min_epochs - 2 + assert model.training_step_invoked == min_epochs * 2 + + def test_trainer_max_steps_accumulate_batches(tmpdir): """Verify model trains according to specified max steps with grad accumulated batches""" model = BoringModel()