Raise a warning if evaluation is triggered with `best` ckpt in case of multiple checkpoint callbacks (#11274)
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
commit 7eab379da2
parent 650c710efa
CHANGELOG.md
@@ -184,6 +184,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Changed `DeviceStatsMonitor` to group metrics based on the logger's `group_separator` ([#11254](https://github.com/PyTorchLightning/pytorch-lightning/pull/11254))
 
+- Raised `UserWarning` if evaluation is triggered with `best` ckpt and trainer is configured with multiple checkpoint callbacks ([#11274](https://github.com/PyTorchLightning/pytorch-lightning/pull/11274))
+
 - `Trainer.logged_metrics` now always contains scalar tensors, even when a Python scalar was logged ([#11270](https://github.com/PyTorchLightning/pytorch-lightning/pull/11270))
pytorch_lightning/trainer/trainer.py
@@ -1381,16 +1381,22 @@ class Trainer(
             ckpt_path = "best"
 
         if ckpt_path == "best":
-            # if user requests the best checkpoint but we don't have it, error
+            if len(self.checkpoint_callbacks) > 1:
+                rank_zero_warn(
+                    f'`.{fn}(ckpt_path="best")` is called with Trainer configured with multiple `ModelCheckpoint`'
+                    " callbacks. It will use the best checkpoint path from first checkpoint callback."
+                )
+
             if not self.checkpoint_callback:
                 raise MisconfigurationException(
                     f'`.{fn}(ckpt_path="best")` is set but `ModelCheckpoint` is not configured.'
                 )
 
             if not self.checkpoint_callback.best_model_path:
                 if self.fast_dev_run:
                     raise MisconfigurationException(
-                        f"You cannot execute `.{fn}()` with `fast_dev_run=True` unless you do"
-                        f" `.{fn}(ckpt_path=PATH)` as no checkpoint path was generated during fitting."
+                        f'You cannot execute `.{fn}(ckpt_path="best")` with `fast_dev_run=True`.'
+                        f" Please pass an exact checkpoint path to `.{fn}(ckpt_path=...)`"
                     )
                 raise MisconfigurationException(
                     f'`.{fn}(ckpt_path="best")` is set but `ModelCheckpoint` is not configured to save the best model.'
                 )
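Note: `Trainer.checkpoint_callback` resolves to the first `ModelCheckpoint` in the callback list, so with several checkpoint callbacks `ckpt_path="best"` silently uses the first callback's `best_model_path`; the `rank_zero_warn` added here surfaces that choice. A minimal sketch of how a user would now trigger the warning (`TinyModel` and `loader` are illustrative stand-ins, not code from this PR):

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    from pytorch_lightning import LightningModule, Trainer
    from pytorch_lightning.callbacks import ModelCheckpoint


    class TinyModel(LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)

        def training_step(self, batch, batch_idx):
            return self.layer(batch[0]).sum()

        def validation_step(self, batch, batch_idx):
            loss = self.layer(batch[0]).sum()
            # log two metrics so each callback monitors something different
            self.log("val_loss", loss)
            self.log("val_acc", -loss)

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)


    def loader():
        return DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=16)


    # two checkpoint callbacks tracking different monitored metrics
    ckpt_on_loss = ModelCheckpoint(monitor="val_loss")
    ckpt_on_acc = ModelCheckpoint(monitor="val_acc", mode="max")

    trainer = Trainer(max_epochs=2, callbacks=[ckpt_on_loss, ckpt_on_acc])
    trainer.fit(TinyModel(), loader(), loader())

    # Emits the new UserWarning and evaluates ckpt_on_loss.best_model_path
    # (the first ModelCheckpoint), not ckpt_on_acc's:
    trainer.validate(ckpt_path="best", dataloaders=loader())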
tests/trainer/test_trainer.py
@@ -774,6 +774,21 @@ def test_tested_checkpoint_path_best(tmpdir, enable_checkpointing, fn):
     trainer_fn(model, ckpt_path="best")
 
 
+def test_best_ckpt_evaluate_raises_warning_with_multiple_ckpt_callbacks():
+    """Test that a warning is raised when evaluation is triggered with the best ckpt and the Trainer is
+    configured with multiple checkpoint callbacks."""
+    ckpt_callback1 = ModelCheckpoint()
+    ckpt_callback1.best_model_path = "foo_best_model.ckpt"
+    ckpt_callback2 = ModelCheckpoint()
+    ckpt_callback2.best_model_path = "bar_best_model.ckpt"
+    trainer = Trainer(callbacks=[ckpt_callback1, ckpt_callback2])
+    trainer.state.fn = TrainerFn.TESTING
+
+    with pytest.warns(UserWarning, match="best checkpoint path from first checkpoint callback"):
+        trainer._Trainer__set_ckpt_path(ckpt_path="best", model_provided=False, model_connected=True)
+
+
 def test_disabled_training(tmpdir):
     """Verify that `limit_train_batches=0` disables the training loop unless `fast_dev_run=True`."""
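The test reaches the private helper through its name-mangled attribute. The mangling is plain Python behavior, not a Lightning API: a double-leading-underscore method `__set_ckpt_path` defined on `Trainer` is stored as `_Trainer__set_ckpt_path`. A tiny illustration of the same mechanism:

    class Example:
        def __helper(self):  # double leading underscore triggers name mangling
            return "called"

    e = Example()
    # Outside the class the method is only reachable under its mangled name,
    # mirroring `trainer._Trainer__set_ckpt_path(...)` in the test above.
    print(e._Example__helper())  # -> called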
tests/trainer/test_trainer.py
@@ -1799,15 +1814,11 @@ def test_trainer_attach_data_pipeline_to_model(tmpdir):
     trainer.fit(model, datamodule=dm)
 
 
-def test_exception_when_testing_or_validating_with_fast_dev_run(tmpdir):
-    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
-    model = BoringModel()
-    trainer.fit(model)
-    with pytest.raises(MisconfigurationException, match=r"\.validate\(\)` with `fast_dev_run=True"):
-        trainer.validate()
-    with pytest.raises(MisconfigurationException, match=r"\.test\(\)` with `fast_dev_run=True"):
-        trainer.test()
+def test_exception_when_testing_or_validating_with_fast_dev_run():
+    trainer = Trainer(fast_dev_run=True)
+    trainer.state.fn = TrainerFn.TESTING
+    with pytest.raises(MisconfigurationException, match=r"with `fast_dev_run=True`. .* pass an exact checkpoint path"):
+        trainer._Trainer__set_ckpt_path(ckpt_path="best", model_provided=False, model_connected=True)
 
 
 class TrainerStagesModel(BoringModel):
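From the user's perspective, the rewritten test corresponds to the scenario below; a sketch reusing the hypothetical `TinyModel` and `loader` from the earlier example. Under `fast_dev_run=True` no checkpoint is ever written, so `best_model_path` stays empty and requesting the best checkpoint should now fail with the more actionable message:

    from pytorch_lightning import Trainer

    trainer = Trainer(fast_dev_run=True)
    trainer.fit(TinyModel(), loader(), loader())

    # Expected to raise MisconfigurationException:
    #   You cannot execute `.validate(ckpt_path="best")` with `fast_dev_run=True`.
    #   Please pass an exact checkpoint path to `.validate(ckpt_path=...)`
    trainer.validate(ckpt_path="best", dataloaders=loader())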