diff --git a/CHANGELOG.md b/CHANGELOG.md index 3e0d8b51f3..daee5ae803 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added + +- Added a friendly error message when attempting to call `Trainer.save_checkpoint()` without a model attached ([#12772](https://github.com/PyTorchLightning/pytorch-lightning/pull/12772)) + + - Enabled `torch.inference_mode` for evaluation and prediction ([#12715](https://github.com/PyTorchLightning/pytorch-lightning/pull/12715)) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 05e3112e60..41e47684fd 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -2364,6 +2364,11 @@ class Trainer( storage_options: parameter for how to save to storage, passed to ``CheckpointIO`` plugin """ + if self.model is None: + raise AttributeError( + "Saving a checkpoint is only possible if a model is attached to the Trainer. Did you call" + " `Trainer.save_checkpoint()` before calling `Trainer.{fit,validate,test,predict}`?" + ) self._checkpoint_connector.save_checkpoint(filepath, weights_only=weights_only, storage_options=storage_options) """ diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index e13e41eda3..316384a20c 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -2133,3 +2133,10 @@ def test_trainer_config_device_ids(monkeypatch, trainer_kwargs, expected_device_ trainer = Trainer(**trainer_kwargs) assert trainer.device_ids == expected_device_ids assert trainer.num_devices == len(expected_device_ids) + + +def test_trainer_save_checkpoint_no_model_attached(): + trainer = Trainer() + assert trainer.model is None + with pytest.raises(AttributeError, match="Saving a checkpoint is only possible if a model is attached"): + trainer.save_checkpoint("checkpoint.ckpt")