Raise a warning if evaluation is triggered with best ckpt in case of multiple checkpoint callbacks (#11274)

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Rohit Gupta 2022-01-04 22:52:32 +05:30 committed by GitHub
parent 650c710efa
commit 7eab379da2
3 changed files with 32 additions and 12 deletions

CHANGELOG.md

@@ -184,6 +184,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Changed `DeviceStatsMonitor` to group metrics based on the logger's `group_separator` ([#11254](https://github.com/PyTorchLightning/pytorch-lightning/pull/11254))

+- Raised `UserWarning` if evaluation is triggered with `best` ckpt and trainer is configured with multiple checkpoint callbacks ([#11274](https://github.com/PyTorchLightning/pytorch-lightning/pull/11274))
+
 - `Trainer.logged_metrics` now always contains scalar tensors, even when a Python scalar was logged ([#11270](https://github.com/PyTorchLightning/pytorch-lightning/pull/11270))

pytorch_lightning/trainer/trainer.py

@@ -1381,16 +1381,22 @@ class Trainer(
             ckpt_path = "best"
         if ckpt_path == "best":
             # if user requests the best checkpoint but we don't have it, error
+            if len(self.checkpoint_callbacks) > 1:
+                rank_zero_warn(
+                    f'`.{fn}(ckpt_path="best")` is called with Trainer configured with multiple `ModelCheckpoint`'
+                    " callbacks. It will use the best checkpoint path from first checkpoint callback."
+                )
             if not self.checkpoint_callback:
                 raise MisconfigurationException(
                     f'`.{fn}(ckpt_path="best")` is set but `ModelCheckpoint` is not configured.'
                 )
             if not self.checkpoint_callback.best_model_path:
                 if self.fast_dev_run:
                     raise MisconfigurationException(
-                        f"You cannot execute `.{fn}()` with `fast_dev_run=True` unless you do"
-                        f" `.{fn}(ckpt_path=PATH)` as no checkpoint path was generated during fitting."
+                        f'You cannot execute `.{fn}(ckpt_path="best")` with `fast_dev_run=True`.'
+                        f" Please pass an exact checkpoint path to `.{fn}(ckpt_path=...)`"
                     )
                 raise MisconfigurationException(
                     f'`.{fn}(ckpt_path="best")` is set but `ModelCheckpoint` is not configured to save the best model.'
                 )
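For context, a minimal sketch (not part of this commit) of how the new warning surfaces to a user. `DemoModel`, the synthetic data, and the checkpoint dirpaths are hypothetical stand-ins; the calls assume the Trainer API used elsewhere in this diff:

import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint


class DemoModel(LightningModule):
    """Hypothetical stand-in for any LightningModule."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def test_step(self, batch, batch_idx):
        x, y = batch
        self.log("test_loss", torch.nn.functional.mse_loss(self.layer(x), y))

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


loader = DataLoader(TensorDataset(torch.randn(64, 32), torch.randn(64, 2)), batch_size=16)

# Two `ModelCheckpoint` callbacks: `ckpt_path="best"` can only follow one of
# them, and after this commit the Trainer says so instead of silently picking
# the first.
trainer = Trainer(
    max_epochs=1,
    callbacks=[ModelCheckpoint(dirpath="ckpts_a"), ModelCheckpoint(dirpath="ckpts_b")],
)
trainer.fit(DemoModel(), loader)

# UserWarning: `.test(ckpt_path="best")` is called with Trainer configured with
# multiple `ModelCheckpoint` callbacks. It will use the best checkpoint path
# from first checkpoint callback.
trainer.test(ckpt_path="best", dataloaders=loader)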

tests/trainer/test_trainer.py

@@ -774,6 +774,21 @@ def test_tested_checkpoint_path_best(tmpdir, enable_checkpointing, fn):
     trainer_fn(model, ckpt_path="best")


+def test_best_ckpt_evaluate_raises_warning_with_multiple_ckpt_callbacks():
+    """Test that a warning is raised if evaluation is triggered with best ckpt and trainer is configured with
+    multiple checkpoint callbacks."""
+    ckpt_callback1 = ModelCheckpoint()
+    ckpt_callback1.best_model_path = "foo_best_model.ckpt"
+    ckpt_callback2 = ModelCheckpoint()
+    ckpt_callback2.best_model_path = "bar_best_model.ckpt"
+    trainer = Trainer(callbacks=[ckpt_callback1, ckpt_callback2])
+    trainer.state.fn = TrainerFn.TESTING
+
+    with pytest.warns(UserWarning, match="best checkpoint path from first checkpoint callback"):
+        trainer._Trainer__set_ckpt_path(ckpt_path="best", model_provided=False, model_connected=True)
+
+
 def test_disabled_training(tmpdir):
     """Verify that `limit_train_batches=0` disables the training loop unless `fast_dev_run=True`."""
@@ -1799,15 +1814,11 @@ def test_trainer_attach_data_pipeline_to_model(tmpdir):
     trainer.fit(model, datamodule=dm)


-def test_exception_when_testing_or_validating_with_fast_dev_run(tmpdir):
-    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
-
-    model = BoringModel()
-    trainer.fit(model)
-
-    with pytest.raises(MisconfigurationException, match=r"\.validate\(\)` with `fast_dev_run=True"):
-        trainer.validate()
-    with pytest.raises(MisconfigurationException, match=r"\.test\(\)` with `fast_dev_run=True"):
-        trainer.test()
+def test_exception_when_testing_or_validating_with_fast_dev_run():
+    trainer = Trainer(fast_dev_run=True)
+    trainer.state.fn = TrainerFn.TESTING
+
+    with pytest.raises(MisconfigurationException, match=r"with `fast_dev_run=True`. .* pass an exact checkpoint path"):
+        trainer._Trainer__set_ckpt_path(ckpt_path="best", model_provided=False, model_connected=True)

class TrainerStagesModel(BoringModel):
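The rewritten `fast_dev_run` test above exercises the error path directly through the private `__set_ckpt_path` helper rather than a full fit. As a hedged standalone sketch of what it asserts (import paths assume the package layout of this era):

from pytorch_lightning import Trainer
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities.exceptions import MisconfigurationException

# `fast_dev_run=True` never saves a checkpoint, so `best_model_path` stays
# empty and resolving `ckpt_path="best"` fails with the new, more specific
# message introduced by this commit.
trainer = Trainer(fast_dev_run=True)
trainer.state.fn = TrainerFn.TESTING  # pretend `.test()` was invoked

try:
    trainer._Trainer__set_ckpt_path(ckpt_path="best", model_provided=False, model_connected=True)
except MisconfigurationException as err:
    print(err)
    # You cannot execute `.test(ckpt_path="best")` with `fast_dev_run=True`.
    # Please pass an exact checkpoint path to `.test(ckpt_path=...)`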