From 7eab379da2fdca542849ed4ad313d0851c2271e3 Mon Sep 17 00:00:00 2001 From: Rohit Gupta Date: Tue, 4 Jan 2022 22:52:32 +0530 Subject: [PATCH] Raise a warning if evaluation is triggered with best ckpt in case of multiple checkpoint callbacks (#11274) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos Mocholí --- CHANGELOG.md | 3 +++ pytorch_lightning/trainer/trainer.py | 12 +++++++++--- tests/trainer/test_trainer.py | 29 +++++++++++++++++++--------- 3 files changed, 32 insertions(+), 12 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 96087fad69..19111ecb3d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -184,6 +184,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Changed `DeviceStatsMonitor` to group metrics based on the logger's `group_separator` ([#11254](https://github.com/PyTorchLightning/pytorch-lightning/pull/11254)) +- Raised `UserWarning` if evaluation is triggered with `best` ckpt and trainer is configured with multiple checkpoint callbacks ([#11274](https://github.com/PyTorchLightning/pytorch-lightning/pull/11274)) + + - `Trainer.logged_metrics` now always contains scalar tensors, even when a Python scalar was logged ([#11270](https://github.com/PyTorchLightning/pytorch-lightning/pull/11270)) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 5a6f67c9ea..d89ad75411 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1381,16 +1381,22 @@ class Trainer( ckpt_path = "best" if ckpt_path == "best": - # if user requests the best checkpoint but we don't have it, error + if len(self.checkpoint_callbacks) > 1: + rank_zero_warn( + f'`.{fn}(ckpt_path="best")` is called with Trainer configured with multiple `ModelCheckpoint`' + " callbacks. It will use the best checkpoint path from first checkpoint callback." 
+ ) + if not self.checkpoint_callback: raise MisconfigurationException( f'`.{fn}(ckpt_path="best")` is set but `ModelCheckpoint` is not configured.' ) + if not self.checkpoint_callback.best_model_path: if self.fast_dev_run: raise MisconfigurationException( - f"You cannot execute `.{fn}()` with `fast_dev_run=True` unless you do" - f" `.{fn}(ckpt_path=PATH)` as no checkpoint path was generated during fitting." + f'You cannot execute `.{fn}(ckpt_path="best")` with `fast_dev_run=True`.' + f" Please pass an exact checkpoint path to `.{fn}(ckpt_path=...)`" ) raise MisconfigurationException( f'`.{fn}(ckpt_path="best")` is set but `ModelCheckpoint` is not configured to save the best model.' diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 52bd2305d7..281afae7c3 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -774,6 +774,21 @@ def test_tested_checkpoint_path_best(tmpdir, enable_checkpointing, fn): trainer_fn(model, ckpt_path="best") +def test_best_ckpt_evaluate_raises_warning_with_multiple_ckpt_callbacks(): + """Test that a warning is raised if best ckpt callback is used for evaluation configured with multiple + checkpoints.""" + + ckpt_callback1 = ModelCheckpoint() + ckpt_callback1.best_model_path = "foo_best_model.ckpt" + ckpt_callback2 = ModelCheckpoint() + ckpt_callback2.best_model_path = "bar_best_model.ckpt" + trainer = Trainer(callbacks=[ckpt_callback1, ckpt_callback2]) + trainer.state.fn = TrainerFn.TESTING + + with pytest.warns(UserWarning, match="best checkpoint path from first checkpoint callback"): + trainer._Trainer__set_ckpt_path(ckpt_path="best", model_provided=False, model_connected=True) + + def test_disabled_training(tmpdir): """Verify that `limit_train_batches=0` disables the training loop unless `fast_dev_run=True`.""" @@ -1799,15 +1814,11 @@ def test_trainer_attach_data_pipeline_to_model(tmpdir): trainer.fit(model, datamodule=dm) -def 
test_exception_when_testing_or_validating_with_fast_dev_run(tmpdir): - trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) - model = BoringModel() - trainer.fit(model) - - with pytest.raises(MisconfigurationException, match=r"\.validate\(\)` with `fast_dev_run=True"): - trainer.validate() - with pytest.raises(MisconfigurationException, match=r"\.test\(\)` with `fast_dev_run=True"): - trainer.test() +def test_exception_when_testing_or_validating_with_fast_dev_run(): + trainer = Trainer(fast_dev_run=True) + trainer.state.fn = TrainerFn.TESTING + with pytest.raises(MisconfigurationException, match=r"with `fast_dev_run=True`. .* pass an exact checkpoint path"): + trainer._Trainer__set_ckpt_path(ckpt_path="best", model_provided=False, model_connected=True) class TrainerStagesModel(BoringModel):